Ticket #4173: 4173-4.patch

File 4173-4.patch, 19.5 KB (added by MostAwesomeDude, 5 years ago)

Next iteration of patch, after habnabit's review

  • twisted/web/test/test_websockets.py

     
     1from twisted.trial import unittest
     2
     3from twisted.web.websockets import (make_accept, mask, CLOSE, NORMAL, PING,
     4    PONG, parse_hybi07_frames)
     5
     6class TestKeys(unittest.TestCase):
     7
     8    def test_make_accept_rfc(self):
     9        """
     10        Test ``make_accept()`` using the keys listed in the RFC for HyBi-07
     11        through HyBi-10.
     12        """
     13
     14        key = "dGhlIHNhbXBsZSBub25jZQ=="
     15
     16        self.assertEqual(make_accept(key), "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=")
     17
     18    def test_make_accept_wikipedia(self):
     19        """
     20        Test ``make_accept()`` using the keys listed on Wikipedia.
     21        """
     22
     23        key = "x3JJHMbDL1EzLkh9GBhXDw=="
     24
     25        self.assertEqual(make_accept(key), "HSmrc0sMlYUkAGmm5OPpG2HaGWk=")
     26
     27class TestHyBi07Helpers(unittest.TestCase):
     28    """
     29    HyBi-07 is best understood as a large family of helper functions which
     30    work together, somewhat dysfunctionally, to produce a mediocre
     31    Thanksgiving every other year.
     32    """
     33
     34    def test_mask_noop(self):
     35        key = "\x00\x00\x00\x00"
     36        self.assertEqual(mask("Test", key), "Test")
     37
     38    def test_mask_noop_long(self):
     39        key = "\x00\x00\x00\x00"
     40        self.assertEqual(mask("LongTest", key), "LongTest")
     41
     42    def test_parse_hybi07_unmasked_text(self):
     43        """
     44        From HyBi-10, 4.7.
     45        """
     46
     47        frame = "\x81\x05Hello"
     48        frames, buf = parse_hybi07_frames(frame)
     49        self.assertEqual(len(frames), 1)
     50        self.assertEqual(frames[0], (NORMAL, "Hello"))
     51        self.assertEqual(buf, "")
     52
     53    def test_parse_hybi07_masked_text(self):
     54        """
     55        From HyBi-10, 4.7.
     56        """
     57
     58        frame = "\x81\x857\xfa!=\x7f\x9fMQX"
     59        frames, buf = parse_hybi07_frames(frame)
     60        self.assertEqual(len(frames), 1)
     61        self.assertEqual(frames[0], (NORMAL, "Hello"))
     62        self.assertEqual(buf, "")
     63
     64    def test_parse_hybi07_unmasked_text_fragments(self):
     65        """
     66        We don't care about fragments. We are totally unfazed.
     67
     68        From HyBi-10, 4.7.
     69        """
     70
     71        frame = "\x01\x03Hel\x80\x02lo"
     72        frames, buf = parse_hybi07_frames(frame)
     73        self.assertEqual(len(frames), 2)
     74        self.assertEqual(frames[0], (NORMAL, "Hel"))
     75        self.assertEqual(frames[1], (NORMAL, "lo"))
     76        self.assertEqual(buf, "")
     77
     78    def test_parse_hybi07_ping(self):
     79        """
     80        From HyBi-10, 4.7.
     81        """
     82
     83        frame = "\x89\x05Hello"
     84        frames, buf = parse_hybi07_frames(frame)
     85        self.assertEqual(len(frames), 1)
     86        self.assertEqual(frames[0], (PING, "Hello"))
     87        self.assertEqual(buf, "")
     88
     89    def test_parse_hybi07_pong(self):
     90        """
     91        From HyBi-10, 4.7.
     92        """
     93
     94        frame = "\x8a\x05Hello"
     95        frames, buf = parse_hybi07_frames(frame)
     96        self.assertEqual(len(frames), 1)
     97        self.assertEqual(frames[0], (PONG, "Hello"))
     98        self.assertEqual(buf, "")
     99
     100    def test_parse_hybi07_close_empty(self):
     101        """
     102        A HyBi-07 close packet may have no body. In that case, it should use
     103        the generic error code 1000, and have no reason.
     104        """
     105
     106        frame = "\x88\x00"
     107        frames, buf = parse_hybi07_frames(frame)
     108        self.assertEqual(len(frames), 1)
     109        self.assertEqual(frames[0], (CLOSE, (1000, "No reason given")))
     110        self.assertEqual(buf, "")
     111
     112    def test_parse_hybi07_close_reason(self):
     113        """
     114        A HyBi-07 close packet must have its first two bytes be a numeric
     115        error code, and may optionally include trailing text explaining why
     116        the connection was closed.
     117        """
     118
     119        frame = "\x88\x0b\x03\xe8No reason"
     120        frames, buf = parse_hybi07_frames(frame)
     121        self.assertEqual(len(frames), 1)
     122        self.assertEqual(frames[0], (CLOSE, (1000, "No reason")))
     123        self.assertEqual(buf, "")
     124
     125    def test_parse_hybi07_partial_no_length(self):
     126        frame = "\x81"
     127        frames, buf = parse_hybi07_frames(frame)
     128        self.assertFalse(frames)
     129        self.assertEqual(buf, "\x81")
     130
     131    def test_parse_hybi07_partial_truncated_length_int(self):
     132        frame = "\x81\xfe"
     133        frames, buf = parse_hybi07_frames(frame)
     134        self.assertFalse(frames)
     135        self.assertEqual(buf, "\x81\xfe")
     136
     137    def test_parse_hybi07_partial_truncated_length_double(self):
     138        frame = "\x81\xff"
     139        frames, buf = parse_hybi07_frames(frame)
     140        self.assertFalse(frames)
     141        self.assertEqual(buf, "\x81\xff")
     142
     143    def test_parse_hybi07_partial_no_data(self):
     144        frame = "\x81\x05"
     145        frames, buf = parse_hybi07_frames(frame)
     146        self.assertFalse(frames)
     147        self.assertEqual(buf, "\x81\x05")
     148
     149    def test_parse_hybi07_partial_truncated_data(self):
     150        frame = "\x81\x05Hel"
     151        frames, buf = parse_hybi07_frames(frame)
     152        self.assertFalse(frames)
     153        self.assertEqual(buf, "\x81\x05Hel")
  • twisted/web/websockets.py

     
     1# Copyright (c) 2011-2012 Oregon State University Open Source Lab
     2#
     3# Permission is hereby granted, free of charge, to any person obtaining a copy
     4# of this software and associated documentation files (the "Software"), to
     5# deal in the Software without restriction, including without limitation the
     6# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
     7# sell copies of the Software, and to permit persons to whom the Software is
     8# furnished to do so, subject to the following conditions:
     9#
     10#    The above copyright notice and this permission notice shall be included
     11#    in all copies or substantial portions of the Software.
     12#
     13#    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
     14#    OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
     15#    MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN
     16#    NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
     17#    DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
     18#    OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
     19#    USE OR OTHER DEALINGS IN THE SOFTWARE.
     20
     21"""
     22The WebSockets protocol (RFC 6455), provided as a resource which wraps a
     23factory.
     24"""
     25
     26from base64 import b64encode, b64decode
     27from hashlib import sha1
     28from struct import pack, unpack
     29
     30from twisted.protocols.policies import ProtocolWrapper, WrappingFactory
     31from twisted.python import log
     32from twisted.web.error import NoResource
     33from twisted.web.resource import IResource
     34from twisted.web.server import NOT_DONE_YET
     35from zope.interface import implements
     36
     37class WSException(Exception):
     38    """
     39    Something stupid happened here.
     40
     41    If this class escapes txWS, then something stupid happened in multiple
     42    places.
     43    """
     44
     45# Control frame specifiers. Some versions of WS have control signals sent
     46# in-band. Adorable, right?
     47
     48NORMAL, CLOSE, PING, PONG = range(4)
     49
     50opcode_types = {
     51    0x0: NORMAL,
     52    0x1: NORMAL,
     53    0x2: NORMAL,
     54    0x8: CLOSE,
     55    0x9: PING,
     56    0xa: PONG,
     57}
     58
     59encoders = {
     60    "base64": b64encode,
     61}
     62
     63decoders = {
     64    "base64": b64decode,
     65}
     66
     67# Authentication for WS.
     68
     69def make_accept(key):
     70    """
     71    Create an "accept" response for a given key.
     72
     73    This dance is expected to somehow magically make WebSockets secure.
     74    """
     75
     76    guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
     77
     78    return sha1("%s%s" % (key, guid)).digest().encode("base64").strip()
     79
     80# Frame helpers.
     81# Separated out to make unit testing a lot easier.
     82# Frames are bonghits in newer WS versions, so helpers are appreciated.
     83
     84def mask(buf, key):
     85    """
     86    Mask or unmask a buffer of bytes with a masking key.
     87
     88    The key must be exactly four bytes long.
     89    """
     90
     91    # This is super-secure, I promise~
     92    key = [ord(i) for i in key]
     93    buf = list(buf)
     94    for i, char in enumerate(buf):
     95        buf[i] = chr(ord(char) ^ key[i % 4])
     96    return "".join(buf)
     97
     98def make_hybi07_frame(buf, opcode=0x1):
     99    """
     100    Make a HyBi-07 frame.
     101
     102    This function always creates unmasked frames, and attempts to use the
     103    smallest possible lengths.
     104    """
     105
     106    if len(buf) > 0xffff:
     107        length = "\x7f%s" % pack(">Q", len(buf))
     108    elif len(buf) > 0x7d:
     109        length = "\x7e%s" % pack(">H", len(buf))
     110    else:
     111        length = chr(len(buf))
     112
     113    # Always make a normal packet.
     114    header = chr(0x80 | opcode)
     115    frame = "%s%s%s" % (header, length, buf)
     116    return frame
     117
     118def parse_hybi07_frames(buf):
     119    """
     120    Parse HyBi-07 frames in a highly compliant manner.
     121    """
     122
     123    start = 0
     124    frames = []
     125
     126    while True:
     127        # If there's not at least two bytes in the buffer, bail.
     128        if len(buf) - start < 2:
     129            break
     130
     131        # Grab the header. This single byte holds some flags nobody cares
     132        # about, and an opcode which nobody cares about.
     133        header = ord(buf[start])
     134        if header & 0x70:
     135            # At least one of the reserved flags is set. Pork chop sandwiches!
     136            raise WSException("Reserved flag in HyBi-07 frame (%d)" % header)
     137            frames.append(("", CLOSE))
     138            return frames, buf
     139
     140        # Get the opcode, and translate it to a local enum which we actually
     141        # care about.
     142        opcode = header & 0xf
     143        try:
     144            opcode = opcode_types[opcode]
     145        except KeyError:
     146            raise WSException("Unknown opcode %d in HyBi-07 frame" % opcode)
     147
     148        # Get the payload length and determine whether we need to look for an
     149        # extra length.
     150        length = ord(buf[start + 1])
     151        masked = length & 0x80
     152        length &= 0x7f
     153
     154        # The offset we're gonna be using to walk through the frame. We use
     155        # this because the offset is variable depending on the length and
     156        # mask.
     157        offset = 2
     158
     159        # Extra length fields.
     160        if length == 0x7e:
     161            if len(buf) - start < 4:
     162                break
     163
     164            length = buf[start + 2:start + 4]
     165            length = unpack(">H", length)[0]
     166            offset += 2
     167        elif length == 0x7f:
     168            if len(buf) - start < 10:
     169                break
     170
     171            # Protocol bug: The top bit of this long long *must* be cleared;
     172            # that is, it is expected to be interpreted as signed. That's
     173            # fucking stupid, if you don't mind me saying so, and so we're
     174            # interpreting it as unsigned anyway. If you wanna send exabytes
     175            # of data down the wire, then go ahead!
     176            length = buf[start + 2:start + 10]
     177            length = unpack(">Q", length)[0]
     178            offset += 8
     179
     180        if masked:
     181            if len(buf) - (start + offset) < 4:
     182                break
     183
     184            key = buf[start + offset:start + offset + 4]
     185            offset += 4
     186
     187        if len(buf) - (start + offset) < length:
     188            break
     189
     190        data = buf[start + offset:start + offset + length]
     191
     192        if masked:
     193            data = mask(data, key)
     194
     195        if opcode == CLOSE:
     196            if len(data) >= 2:
     197                # Gotta unpack the opcode and return usable data here.
     198                data = unpack(">H", data[:2])[0], data[2:]
     199            else:
     200                # No reason given; use generic data.
     201                data = 1000, "No reason given"
     202
     203        frames.append((opcode, data))
     204        start += offset + length
     205
     206    return frames, buf[start:]
     207
     208class WebSocketsProtocol(ProtocolWrapper):
     209    """
     210    Protocol which wraps another protocol to provide a WebSockets transport
     211    layer.
     212    """
     213
     214    buf = ""
     215    codec = None
     216
     217    def __init__(self, *args, **kwargs):
     218        ProtocolWrapper.__init__(self, *args, **kwargs)
     219        self.pending_frames = []
     220
     221    def parseFrames(self):
     222        """
     223        Find frames in incoming data and pass them to the underlying protocol.
     224        """
     225
     226        try:
     227            frames, self.buf = parse_hybi07_frames(self.buf)
     228        except WSException, wse:
     229            # Couldn't parse all the frames, something went wrong, let's bail.
     230            self.close(wse.args[0])
     231            return
     232
     233        for frame in frames:
     234            opcode, data = frame
     235            if opcode == NORMAL:
     236                # Business as usual. Decode the frame, if we have a decoder.
     237                if self.codec:
     238                    data = decoders[self.codec](data)
     239                # Pass the frame to the underlying protocol.
     240                ProtocolWrapper.dataReceived(self, data)
     241            elif opcode == CLOSE:
     242                # The other side wants us to close. I wonder why?
     243                reason, text = data
     244                log.msg("Closing connection: %r (%d)" % (text, reason))
     245
     246                # Close the connection.
     247                self.close()
     248            elif opcode == PING:
     249                # 5.5.2 PINGs must be responded to with PONGs.
     250                # 5.5.3 PONGs must contain the data that was sent with the
     251                # provoking PING.
     252                self.transport.write(make_hybi07_packet(data, opcode=0xa))
     253
     254    def sendFrames(self):
     255        """
     256        Send all pending frames.
     257        """
     258
     259        for frame in self.pending_frames:
     260            # Encode the frame before sending it.
     261            if self.codec:
     262                frame = encoders[self.codec](frame)
     263            packet = make_hybi07_frame(frame)
     264            self.transport.write(packet)
     265        self.pending_frames = []
     266
     267    def dataReceived(self, data):
     268        self.buf += data
     269
     270        self.parseFrames()
     271
     272        # Kick any pending frames. This is needed because frames might have
     273        # started piling up early; we can get write()s from our protocol above
     274        # when they makeConnection() immediately, before our browser client
     275        # actually sends any data. In those cases, we need to manually kick
     276        # pending frames.
     277        if self.pending_frames:
     278            self.sendFrames()
     279
     280    def write(self, data):
     281        """
     282        Write to the transport.
     283
     284        This method will only be called by the underlying protocol.
     285        """
     286
     287        self.pending_frames.append(data)
     288        self.sendFrames()
     289
     290    def writeSequence(self, data):
     291        """
     292        Write a sequence of data to the transport.
     293
     294        This method will only be called by the underlying protocol.
     295        """
     296
     297        self.pending_frames.extend(data)
     298        self.sendFrames()
     299
     300    def close(self, reason=""):
     301        """
     302        Close the connection.
     303
     304        This includes telling the other side we're closing the connection.
     305
     306        If the other side didn't signal that the connection is being closed,
     307        then we might not see their last message, but since their last message
     308        should, according to the spec, be a simple acknowledgement, it
     309        shouldn't be a problem.
     310        """
     311
     312        # Send a closing frame. It's only polite. (And might keep the browser
     313        # from hanging.)
     314        frame = make_hybi07_frame(reason, opcode=0x8)
     315        self.transport.write(frame)
     316
     317        self.loseConnection()
     318
     319class WebSocketsFactory(WrappingFactory):
     320    """
     321    Factory which wraps another factory to provide WebSockets frames for all
     322    of its protocols.
     323
     324    This factory does not provide the HTTP headers required to perform a
     325    WebSockets handshake; see C{WebSocketsResource}.
     326    """
     327
     328    protocol = WebSocketsProtocol
     329
     330class WebSocketsResource(object):
     331    """
     332    A resource for serving a protocol through WebSockets.
     333
     334    This class wraps a factory and connects it to WebSockets clients. Each
     335    connecting client will be connected to a new protocol of the factory.
     336
     337    Due to unresolved questions of logistics, this resource cannot have
     338    children.
     339    """
     340
     341    implements(IResource)
     342
     343    isLeaf = True
     344
     345    def __init__(self, factory):
     346        self._factory = WebSocketsFactory(factory)
     347
     348    def getChildWithDefault(self, name, request):
     349        return NoResource("No such child resource.")
     350
     351    def putChild(self, path, child):
     352        pass
     353
     354    def render(self, request):
     355        """
     356        Render a request.
     357
     358        We're not actually rendering a request. We are secretly going to
     359        handle a WebSockets connection instead.
     360        """
     361
     362        # If we fail at all, we're gonna fail with 400 and no response.
     363        # You might want to pop open the RFC and read along.
     364        failed = False
     365
     366        if request.method != "GET":
     367            # 4.2.1.1 GET is required.
     368            failed = True
     369
     370        upgrade = request.getHeader("Upgrade")
     371        if upgrade is None or "websocket" not in upgrade.lower():
     372            # 4.2.1.3 Upgrade: WebSocket is required.
     373            failed = True
     374
     375        connection = request.getHeader("Connection")
     376        if connection is None or "upgrade" not in connection.lower():
     377            # 4.2.1.4 Connection: Upgrade is required.
     378            failed = True
     379
     380        key = request.getHeader("Sec-WebSocket-Key")
     381        if key is None:
     382            # 4.2.1.5 The challenge key is required.
     383            failed = True
     384
     385        version = request.getHeader("Sec-WebSocket-Version")
     386        if version != "13":
     387            # 4.2.1.6 Only version 13 works.
     388            failed = True
     389            # 4.4 Forward-compatible version checking.
     390            request.setHeader("Sec-WebSocket-Version", "13")
     391
     392        # Check whether a codec is needed. WS calls this a "protocol" for
     393        # reasons I cannot fathom.
     394        codec = request.getHeader("Sec-WebSocket-Protocol")
     395
     396        if codec:
     397            if codec not in encoders or codec not in decoders:
     398                log.msg("Codec %s is not implemented" % codec)
     399                failed = True
     400
     401        if failed:
     402            request.setResponseCode(400)
     403            return ""
     404
     405        # We are going to finish this handshake. We will return a valid status
     406        # code.
     407        # 4.2.2.5.1 101 Switching Protocols
     408        request.setResponseCode(101)
     409        # 4.2.2.5.2 Upgrade: websocket
     410        request.setHeader("Upgrade", "WebSocket")
     411        # 4.2.2.5.3 Connection: Upgrade
     412        request.setHeader("Connection", "Upgrade")
     413        # 4.2.2.5.4 Response to the key challenge
     414        request.setHeader("Sec-WebSocket-Accept", make_accept(key))
     415        # 4.2.2.5.5 Optional codec declaration
     416        if codec:
     417            request.setHeader("Sec-WebSocket-Protocol", codec)
     418
     419        # Create the protocol. This could fail, in which case we deliver an
     420        # error status. Status 502 was decreed by glyph; blame him.
     421        protocol = self._factory.buildProtocol(request.transport.getPeer())
     422        if not protocol:
     423            request.setResponseCode(502)
     424            return ""
     425        if codec:
     426            protocol.codec = codec
     427
     428        # Provoke request into flushing headers and finishing the handshake.
     429        request.write("")
     430
     431        # And now take matters into our own hands. We shall manage the
     432        # transport's lifecycle.
     433        transport, request.transport = request.transport, None
     434
     435        # Connect the transport to our factory, and make things go. We need to
     436        # do some stupid stuff here; see #3204, which could fix it.
     437        transport.protocol = protocol
     438        protocol.makeConnection(transport)
     439
     440        return NOT_DONE_YET
     441
     442__all__ = ("WebSocketsResource",)