Index: twisted/web/websocket.py
===================================================================
--- twisted/web/websocket.py	(revision 29410)
+++ twisted/web/websocket.py	(working copy)
@@ -2,6 +2,8 @@
 # Copyright (c) 2009 Twisted Matrix Laboratories.
 # See LICENSE for details.
 
+# Based on http://twistedmatrix.com/trac/browser/branches/websocket-4173-2/twisted/web/websocket.py
+
 """
 WebSocket server protocol.
 
@@ -11,20 +13,25 @@
 @since: 10.1
 """
 
-
 from twisted.web.http import datetimeToString
 from twisted.web.server import Request, Site, version, unquote
+import struct
+import re
+import hashlib
 
 
-
 class WebSocketRequest(Request):
     """
     A general purpose L{Request} supporting connection upgrade for WebSocket.
     """
+    handlerFactory = None
 
+    def isWebSocket(self):
+        return self.requestHeaders.getRawHeaders("Upgrade") == ["WebSocket"] and \
+            self.requestHeaders.getRawHeaders("Connection") == ["Upgrade"]
+
     def process(self):
-        if (self.requestHeaders.getRawHeaders("Upgrade") == ["WebSocket"] and
-            self.requestHeaders.getRawHeaders("Connection") == ["Upgrade"]):
+        if self.isWebSocket():
             return self.processWebSocket()
         else:
             return Request.process(self)
@@ -37,6 +44,9 @@
         # get site from channel
         self.site = self.channel.site
 
+        # set an empty handler attribute
+        self.handler = None
+
         # set various default headers
         self.setHeader("server", version)
         self.setHeader("date", datetimeToString())
@@ -47,45 +57,78 @@
         self.renderWebSocket()
 
 
-    def _checkClientHandshake(self):
-        """
-        Verify client handshake, closing the connection in case of problem.
+    def _handshake75(self):
+        origin  = self.requestHeaders.getRawHeaders("Origin",   [None])[0]
+        host    = self.requestHeaders.getRawHeaders("Host",     [None])[0]
+        if not origin or not host:
+            return
+        
+        protocol = self.requestHeaders.getRawHeaders("WebSocket-Protocol", [None])[0]
+        if protocol and protocol not in self.site.supportedProtocols:
+            return
+            
+        if self.isSecure():
+            scheme = "wss"
+        else:
+            scheme = "ws"
+        location = "%s://%s%s" % (scheme, host, self.uri)
+        handshake = [
+            "HTTP/1.1 101 Web Socket Protocol Handshake",
+            "Upgrade: WebSocket",
+            "Connection: Upgrade",
+            "WebSocket-Origin: %s" % origin,
+            "WebSocket-Location: %s" % location,
+            ]
+        if protocol is not None:
+            handshake.append("WebSocket-Protocol: %s" % protocol)
+                
+        return handshake
+    
+    def _handshake76(self):
+        origin  = self.requestHeaders.getRawHeaders("Origin",   [None])[0]
+        host    = self.requestHeaders.getRawHeaders("Host",     [None])[0]
+        if not origin or not host:
+            return None, None
+        
+        protocol = self.requestHeaders.getRawHeaders("Sec-WebSocket-Protocol", [None])[0]
+        if protocol and protocol not in self.site.supportedProtocols:
+            return None, None
 
-        @return: C{None} if a problem was detected, or a tuple of I{Origin}
-            header, I{Host} header, I{WebSocket-Protocol} header, and
-            C{WebSocketHandler} instance. The I{WebSocket-Protocol} header will
-            be C{None} if not specified by the client.
-        """
-        def finish():
-            self.channel.transport.loseConnection()
-        if self.queued:
-            return finish()
-        originHeaders = self.requestHeaders.getRawHeaders("Origin", [])
-        if len(originHeaders) != 1:
-            return finish()
-        hostHeaders = self.requestHeaders.getRawHeaders("Host", [])
-        if len(hostHeaders) != 1:
-            return finish()
-
-        handlerFactory = self.site.handlers.get(self.uri)
-        if not handlerFactory:
-            return finish()
-        transport = WebSocketTransport(self)
-        handler = handlerFactory(transport)
-        transport._attachHandler(handler)
-
-        protocolHeaders = self.requestHeaders.getRawHeaders(
-            "WebSocket-Protocol", [])
-        if len(protocolHeaders) not in (0,  1):
-            return finish()
-        if protocolHeaders:
-            if protocolHeaders[0] not in self.site.supportedProtocols:
-                return finish()
-            protocolHeader = protocolHeaders[0]
+        if self.isSecure():
+            scheme = "wss"
         else:
-            protocolHeader = None
-        return originHeaders[0], hostHeaders[0], protocolHeader, handler
+            scheme = "ws"
+        location = "%s://%s%s" % (scheme, host, self.uri)
+        handshake = [
+            "HTTP/1.1 101 Web Socket Protocol Handshake",
+            "Upgrade: WebSocket",
+            "Connection: Upgrade",
+            "Sec-WebSocket-Origin: %s" % origin,
+            "Sec-WebSocket-Location: %s" % location,
+            ]
+        if protocol is not None:
+            handshake.append("Sec-WebSocket-Protocol: %s" % protocol)
+        
+        self.channel.setRawMode() 
+        
+        # Refer to 5.2 4-9 of the draft 76
+        key1 = self.requestHeaders.getRawHeaders('Sec-WebSocket-Key1', [None])[0]
+        key2 = self.requestHeaders.getRawHeaders('Sec-WebSocket-Key2', [None])[0]
+        key3 = self.content.getvalue()
+        
+        def extract_nums(s): return int(''.join(re.findall(r'[0-9]', s)))
+        def count_spaces(s): return len(re.findall(r' ', s))
+        part1 = extract_nums(key1) / count_spaces(key1)
+        part2 = extract_nums(key2) / count_spaces(key2)
+        challenge = hashlib.md5(struct.pack('>ii8s', part1, part2, key3)).digest()
+        
+        return handshake, challenge
 
+    def gotLength(self, length):
+        spec76 = self.requestHeaders.getRawHeaders("Sec-WebSocket-Key1", [None])[0]
+        if self.isWebSocket() and spec76:
+            self.channel.headerReceived("content-length: 8")
+        Request.gotLength(self, length)
 
     def renderWebSocket(self):
         """
@@ -95,31 +138,37 @@
         connection will be closed. Otherwise, the response to the handshake is
         sent and a C{WebSocketHandler} is created to handle the request.
         """
-        check =  self._checkClientHandshake()
-        if check is None:
+        if self.queued:
+            self.channel.transport.loseConnection()
             return
-        originHeader, hostHeader, protocolHeader, handler = check
-        self.startedWriting = True
-        handshake = [
-            "HTTP/1.1 101 Web Socket Protocol Handshake",
-            "Upgrade: WebSocket",
-            "Connection: Upgrade"]
-        handshake.append("WebSocket-Origin: %s" % (originHeader))
-        if self.isSecure():
-            scheme = "wss"
+        
+        if self.requestHeaders.getRawHeaders("Sec-WebSocket-Key1", [None])[0]:
+            handshake, challenge_response = self._handshake76()
         else:
-            scheme = "ws"
-        handshake.append(
-            "WebSocket-Location: %s://%s%s" % (
-            scheme, hostHeader, self.uri))
+            handshake = self._handshake75()
+            challenge_response = None
+        
+        if not handshake:
+            self.channel.transport.loseConnection()
+            return
 
-        if protocolHeader is not None:
-            handshake.append("WebSocket-Protocol: %s" % protocolHeader)
+        handlerFactory = self.site.handlers.get(self.uri) or self.handlerFactory
+        if not handlerFactory:
+            return self.channel.transport.loseConnection()
+        transport = WebSocketTransport(self)
+        handler = handlerFactory(transport)
+        transport._attachHandler(handler)
+        self.handler = handler
 
+        self.startedWriting = True
+        
         for header in handshake:
             self.write("%s\r\n" % header)
 
         self.write("\r\n")
+        if challenge_response:
+            self.write(challenge_response)
+        
         self.channel.setRawMode()
         # XXX we probably don't want to set _transferDecoder
         self.channel._transferDecoder = WebSocketFrameDecoder(
