[Twisted-Python] Updated TLS patch

Jp Calderone exarkun at intarweb.us
Thu May 1 18:33:14 EDT 2003


  I've taken Jeremy Rossi's TLS patch and updated it for current CVS, and
also cleaned up the parts of it that broke regular TCP when SSL was
unavailable.

  What I have been completely unable to do is prevent this from introducing
a branch/function call into the common path for doRead/doWrite, even when
TLS is not in use.

  In order of desirability (imho), this can be fixed by:

    Rewrite tcp.py, more or less completely, *without* juggling methods as
it currently does.

    Take the _TLS_* and _NOTLS_* functions and just inline them.

    Create a new transport, TLS, along with all the associated
methods/functions (connect/listen/etc) so as to keep TLS code out of tcp.py
entirely.

  Patch attached.

  Jp

-- 
#!/bin/bash
( LIST=(~/.sigs/*.sig)
  cat ${LIST[$(($RANDOM % ${#LIST[*]}))]}
  echo -- $'\n' `uptime | sed -e 's/.*m//'` ) > ~/.signature
-- 
 up 42 days, 19:04, 4 users, load average: 0.35, 0.16, 0.19
-------------- next part --------------
Index: twisted/internet/ssl.py
===================================================================
RCS file: /cvs/Twisted/twisted/internet/ssl.py,v
retrieving revision 1.40
diff -u -r1.40 ssl.py
--- twisted/internet/ssl.py	2 Apr 2003 04:11:32 -0000	1.40
+++ twisted/internet/ssl.py	1 May 2003 22:20:02 -0000
@@ -95,116 +95,13 @@
         return SSL.Context(SSL.SSLv3_METHOD)
 
 
-class Connection(tcp.Connection):
-    """I am an SSL connection.
-    """
-
-    __implements__ = tcp.Connection.__implements__, interfaces.ISSLTransport
-    
-    writeBlockedOnRead = 0
-    readBlockedOnWrite= 0
-    sslShutdown = 0
-    
-    def getPeerCertificate(self):
-        """Return the certificate for the peer."""
-        return self.socket.get_peer_certificate()
-
-    def _postLoseConnection(self):
-        """Gets called after loseConnection(), after buffered data is sent.
-
-        We close the SSL transport layer, and if the other side hasn't
-        closed it yet we start reading, waiting for a ZeroReturnError
-        which will indicate the SSL shutdown has completed.
-        """
-        try:
-            done = self.socket.shutdown()
-            self.sslShutdown = 1
-        except SSL.Error:
-            return main.CONNECTION_LOST
-        if done:
-            return main.CONNECTION_DONE
-        else:
-            # we wait for other side to close SSL connection -
-            # this will be signaled by SSL.ZeroReturnError when reading
-            # from the socket
-            self.stopWriting()
-            self.startReading()
-            return None # don't close socket just yet
-    
-    def doRead(self):
-        """See tcp.Connection.doRead for details.
-        """
-        if self.writeBlockedOnRead:
-            self.writeBlockedOnRead = 0
-            return self.doWrite()
-        try:
-            return tcp.Connection.doRead(self)
-        except SSL.ZeroReturnError:
-            # close SSL layer, since other side has done so, if we haven't
-            if not self.sslShutdown:
-                try:
-                    self.socket.shutdown()
-                    self.sslShutdown = 1
-                except SSL.Error:
-                    pass
-            return main.CONNECTION_DONE
-        except SSL.WantReadError:
-            return
-        except SSL.WantWriteError:
-            self.readBlockedOnWrite = 1
-            self.startWriting()
-            return
-        except SSL.Error:
-            return main.CONNECTION_LOST
-
-    def doWrite(self):
-        if self.readBlockedOnWrite:
-            self.readBlockedOnWrite = 0
-            if not self.dataBuffer: self.stopWriting()
-            return self.doRead()
-        return tcp.Connection.doWrite(self)
-    
-    def writeSomeData(self, data):
-        """See tcp.Connection.writeSomeData for details.
-        """
-        if not data:
-            return 0
-
-        try:
-            return tcp.Connection.writeSomeData(self, data)
-        except SSL.WantWriteError:
-            return 0
-        except SSL.WantReadError:
-            self.writeBlockedOnRead = 1
-            return 0
-        except SSL.Error:
-            return main.CONNECTION_LOST
-
-    def _closeSocket(self):
-        """Called to close our socket."""
-        try:
-            self.socket.sock_shutdown(2)
-        except socket.error:
-            try:
-                self.socket.close()
-            except socket.error:
-                log.deferr()
-
-
-
-class Client(Connection, tcp.Client):
+class Client(tcp.Client):
     """I am an SSL client."""
     def __init__(self, host, port, bindAddress, ctxFactory, connector, reactor=None):
         # tcp.Client.__init__ depends on self.ctxFactory being set
         self.ctxFactory = ctxFactory
         tcp.Client.__init__(self, host, port, bindAddress, connector, reactor)
 
-    def createInternetSocket(self):
-        """(internal) create an SSL socket
-        """
-        sock = tcp.Client.createInternetSocket(self)
-        return SSL.Connection(self.ctxFactory.getContext(), sock)
-
     def getHost(self):
         """Returns a tuple of ('SSL', hostname, port).
 
@@ -219,16 +116,14 @@
         """
         return ('SSL',)+self.addr
 
+    def _finishInit(self, whenDone, skt, error, reactor):
+        tcp.Client._finishInit(self, whenDone, skt, error, reactor)
+        self.startTLS(self.ctxFactory)
 
 
-class Server(Connection, tcp.Server):
+class Server(tcp.Server):
     """I am an SSL server.
     """
-    
-    def __init__(*args, **kw):
-        # We don't want Connection's __init__
-        tcp.Server.__init__(*args, **kw)
-    
     def getHost(self):
         """Returns a tuple of ('SSL', hostname, port).
 
@@ -257,33 +152,12 @@
         """
         sock = tcp.Port.createInternetSocket(self)
         return SSL.Connection(self.ctxFactory.getContext(), sock)
-    
-    def doRead(self):
-        """Called when my socket is ready for reading.
 
-        This accepts a connection and calls self.protocol() to handle the
-        wire-level protocol.
-        """
-        try:
-            try:
-                skt, addr = self.socket.accept()
-            except socket.error, e:
-                if e.args[0] == tcp.EWOULDBLOCK:
-                    return
-                raise
-            except SSL.Error:
-                log.deferr()
-                return
-            protocol = self.factory.buildProtocol(addr)
-            if protocol is None:
-                skt.close()
-                return
-            s = self.sessionno
-            self.sessionno = s+1
-            transport = self.transport(skt, protocol, addr, self, s)
-            protocol.makeConnection(transport)
-        except:
-            log.deferr()
+    def _preMakeConnection(self, transport):
+        # *Don't* call startTLS here
+        # The transport already has the SSL.Connection object from above
+        transport._startTLS()
+        return tcp.Port._preMakeConnection(self, transport)
 
 
 class Connector(base.BaseConnector):
Index: twisted/internet/tcp.py
===================================================================
RCS file: /cvs/Twisted/twisted/internet/tcp.py,v
retrieving revision 1.116
diff -u -r1.116 tcp.py
--- twisted/internet/tcp.py	1 May 2003 16:40:47 -0000	1.116
+++ twisted/internet/tcp.py	1 May 2003 22:20:03 -0000
@@ -39,6 +39,11 @@
 except ImportError:
     fcntl = None
 
+try:
+    from OpenSSL import SSL
+except ImportError:
+    SSL = None
+
 if os.name == 'nt':
     # we hardcode these since windows actually wants e.g.
     # WSAEALREADY rather than EALREADY. Possibly we should
@@ -88,12 +93,46 @@
 
     __implements__ = abstract.FileDescriptor.__implements__, interfaces.ITCPTransport
 
+    # TLS state, defined whether or not OpenSSL is importable: doRead,
+    # writeSomeData, _closeSocket and _postLoseConnection consult self.TLS
+    # on every connection, including plain TCP ones.
+    TLS = 0
+    writeBlockedOnRead = 0
+    readBlockedOnWrite = 0
+    sslShutdown = 0
+
     def __init__(self, skt, protocol, reactor=None):
         abstract.FileDescriptor.__init__(self, reactor=reactor)
         self.socket = skt
         self.socket.setblocking(0)
         self.fileno = skt.fileno
         self.protocol = protocol
+
+    def startTLS(self, ctx):
+        """Switch this connection to TLS.
+
+        ctx is a context factory; its getContext() result is used to wrap
+        the current socket in an SSL.Connection.  Raises RuntimeError if
+        OpenSSL is unavailable or TLS has already been started.
+        """
+        if not SSL:
+            raise RuntimeError("No SSL support available")
+        if self.TLS:
+            raise RuntimeError("TLS already started on this connection")
+
+        # Wrap the socket before setting the TLS flag, so that a failure in
+        # getContext() leaves the connection in plain TCP mode.
+        self.socket = SSL.Connection(ctx.getContext(), self.socket)
+        self.fileno = self.socket.fileno
+        self._startTLS()
+
+    def _startTLS(self):
+        """Mark the connection as TLS without wrapping the socket.
+
+        Used directly when the socket is already an SSL.Connection, as for
+        connections accepted by the SSL port in twisted/internet/ssl.py.
+        """
+        self.TLS = 1
 
     def doRead(self):
         """Calls self.protocol.dataReceived with all available data.
@@ -103,6 +126,13 @@
         lost through an error in the physical recv(), this function will return
         the result of the dataReceived call.
         """
+        # Test the flag directly rather than indexing a tuple of bound
+        # methods, which would build two method objects on every read.
+        if self.TLS:
+            return self._TLS_doRead()
+        return self._NOTLS_doRead()
+
+    def _NOTLS_doRead(self):
         try:
             data = self.socket.recv(self.bufferSize)
         except socket.error, se:
@@ -113,6 +139,63 @@
         if not data:
             return main.CONNECTION_LOST
         return self.protocol.dataReceived(data)
+
+    def _TLS_doRead(self):
+        """doRead used once TLS is active.
+
+        Resumes a write that was blocked waiting for readable data, then
+        reads through _NOTLS_doRead, translating OpenSSL conditions into
+        the return codes doRead callers expect.
+        """
+        if self.writeBlockedOnRead:
+            self.writeBlockedOnRead = 0
+            return self.doWrite()
+        try:
+            return self._NOTLS_doRead()
+        except SSL.ZeroReturnError:
+            # close SSL layer, since other side has done so, if we haven't
+            if not self.sslShutdown:
+                try:
+                    self.socket.shutdown()
+                    self.sslShutdown = 1
+                except SSL.Error:
+                    pass
+            return main.CONNECTION_DONE
+        except SSL.WantReadError:
+            return
+        except SSL.WantWriteError:
+            # The read cannot proceed until the socket is writable;
+            # _TLS_doWrite retries the read when that happens.
+            self.readBlockedOnWrite = 1
+            self.startWriting()
+            return
+        except SSL.Error:
+            return main.CONNECTION_LOST
+
+    def _TLS_doWrite(self):
+        """doWrite used once TLS is active.
+
+        If a read is waiting for writability, retry that read (stopping
+        write notifications when nothing is buffered); otherwise defer to
+        abstract.FileDescriptor.doWrite.
+        """
+        if self.readBlockedOnWrite:
+            self.readBlockedOnWrite = 0
+            # XXX - This is touching internal guts bad bad bad
+            if not self.dataBuffer:
+                self.stopWriting()
+            return self.doRead()
+        return abstract.FileDescriptor.doWrite(self)
+
+    def doWrite(self):
+        """Called when the socket is writable; dispatch on TLS mode.
+
+        Mirrors doRead: without this, _TLS_doWrite is never invoked and a
+        read blocked on SSL.WantWriteError is never retried.
+        """
+        if self.TLS:
+            return self._TLS_doWrite()
+        return abstract.FileDescriptor.doWrite(self)
 
     def writeSomeData(self, data):
         """Connection.writeSomeData(data) -> #of bytes written | CONNECTION_LOST
@@ -120,6 +179,13 @@
         the number of bytes read (which is positive) or a connection error code
         (which is negative)
         """
+        # Branch on the flag rather than indexing a tuple of bound methods,
+        # which would build two method objects on every write.
+        if self.TLS:
+            return self._TLS_writeSomeData(data)
+        return self._NOTLS_writeSomeData(data)
+
+    def _NOTLS_writeSomeData(self, data):
         try:
             return self.socket.send(data)
         except socket.error, se:
@@ -128,8 +190,31 @@
             else:
                 return main.CONNECTION_LOST
 
+    def _TLS_writeSomeData(self, data):
+        """writeSomeData used once TLS is active.
+
+        Returns the number of bytes written, 0 when nothing could be
+        written yet, or main.CONNECTION_LOST on an SSL error.  If OpenSSL
+        needs to read before it can write, writeBlockedOnRead is set so
+        that the next doRead retries the write.
+        """
+        if not data:
+            return 0
+        try:
+            return self._NOTLS_writeSomeData(data)
+        except SSL.WantWriteError:
+            return 0
+        except SSL.WantReadError:
+            self.writeBlockedOnRead = 1
+            return 0
+        except SSL.Error:
+            return main.CONNECTION_LOST
+
     def _closeSocket(self):
         """Called to close our socket."""
+        return (self._NOTLS_closeSocket, self._TLS_closeSocket)[self.TLS]()
+
+    def _NOTLS_closeSocket(self):
         # This used to close() the socket, but that doesn't *really* close if
         # there's another reference to it in the TCP/IP stack, e.g. if it was
         # was inherited by a subprocess. And we really do want to close the
@@ -139,6 +216,21 @@
         except socket.error:
             pass
 
+    def _TLS_closeSocket(self):
+        """Close the socket underlying an SSL.Connection.
+
+        self.socket.shutdown() is the SSL-level shutdown here, so the raw
+        socket is shut down with sock_shutdown(2).  If that fails, fall
+        back to close(); a failure there is logged, not silently ignored.
+        """
+        try:
+            self.socket.sock_shutdown(2)
+        except socket.error:
+            try:
+                self.socket.close()
+            except socket.error:
+                log.deferr()
+
     def connectionLost(self, reason):
         """See abstract.FileDescriptor.connectionLost().
         """
@@ -173,6 +259,33 @@
 
     def setTcpNoDelay(self, enabled):
         self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, enabled)
+    
+    def _postLoseConnection(self):
+        """Gets called after loseConnection(), after buffered data is sent.
+
+        We close the SSL transport layer, and if the other side hasn't
+        closed it yet we start reading, waiting for a ZeroReturnError
+        which will indicate the SSL shutdown has completed.
+        """
+        if not self.TLS:
+            return abstract.FileDescriptor._postLoseConnection(self)
+
+        try:
+            done = self.socket.shutdown()
+            self.sslShutdown = 1
+        except SSL.Error:
+            return main.CONNECTION_LOST
+        if done:
+            return main.CONNECTION_DONE
+        else:
+            # we wait for other side to close SSL connection -
+            # this will be signaled by SSL.ZeroReturnError when reading
+            # from the socket
+            self.stopWriting()
+            self.startReading()
+            
+            # don't close socket just yet
+            return None
 
 
 class BaseClient(Connection):
@@ -191,6 +304,11 @@
         else:
             reactor.callLater(0, self.failIfNotConnected, error)
 
+    def startTLS(self, ctx):
+        holder = Connection.startTLS(self, ctx)
+        self.socket.set_connect_state()
+        return holder
+
     def stopConnecting(self):
         """Stop attempt to connect."""
         self.failIfNotConnected(error.UserError())
@@ -357,6 +475,11 @@
         """
         return self.repstr
 
+    def startTLS(self, ctx):
+        holder = Connection.startTLS(self, ctx)
+        self.socket.set_accept_state()
+        return holder
+
     def getHost(self):
         """Returns a tuple of ('INET', hostname, port).
 
@@ -455,6 +578,7 @@
                     elif e.args[0] == EPERM:
                         continue
                     raise
+                
                 protocol = self.factory.buildProtocol(addr)
                 if protocol is None:
                     skt.close()
@@ -462,11 +586,22 @@
                 s = self.sessionno
                 self.sessionno = s+1
                 transport = self.transport(skt, protocol, addr, self, s)
+                transport = self._preMakeConnection(transport)
                 protocol.makeConnection(transport)
             else:
                 self.numberAccepts = self.numberAccepts+20
         except:
+            # Note that in TLS mode, this will possibly catch SSL.Errors
+            # raised by self.socket.accept()
+            #
+            # There is no "except SSL.Error:" above because SSL may be
+            # None if there is no SSL support.  In any case, all the
+            # "except SSL.Error:" suite would probably do is log.deferr()
+            # and return, so handling it here works just as well.
             log.deferr()
+    
+    def _preMakeConnection(self, transport):
+        return transport
 
     def loseConnection(self, connDone=failure.Failure(main.CONNECTION_DONE)):
         """Stop accepting connections on this port.
Index: twisted/test/test_ssl.py
===================================================================
RCS file: /cvs/Twisted/twisted/test/test_ssl.py,v
retrieving revision 1.8
diff -u -r1.8 test_ssl.py
--- twisted/test/test_ssl.py	1 May 2003 07:02:10 -0000	1.8
+++ twisted/test/test_ssl.py	1 May 2003 22:20:03 -0000
@@ -17,19 +17,23 @@
 from __future__ import nested_scopes
 from twisted.trial import unittest
 from twisted.internet import protocol, reactor
+from twisted.protocols import basic
+
 try:
-    import OpenSSL
+    from OpenSSL import SSL
     from twisted.internet import ssl
 except ImportError:
-    OpenSSL = None
+    SSL = None
+
 import os
 import test_tcp
 
 
+certPath = os.path.join(os.path.split(test_tcp.__file__)[0], "server.pem")
+
 class StolenTCPTestCase(test_tcp.ProperlyCloseFilesTestCase, test_tcp.WriteDataTestCase):
     
     def setUp(self):
-        certPath = os.path.join(os.path.split(test_tcp.__file__)[0], "server.pem")
         f = protocol.ServerFactory()
         f.protocol = protocol.Protocol
         self.listener = reactor.listenSSL(
@@ -49,5 +53,117 @@
 
         self.totalConnections = 0
 
-if not OpenSSL:
-    del StolenTCPTestCase
+# The TLS context factories subclass classes from twisted.internet.ssl, which
+# is only imported when OpenSSL is available, so define them only then.
+if SSL:
+    class ClientTLSContext(ssl.ClientContextFactory):
+        isClient = 1
+        def getContext(self):
+            return SSL.Context(SSL.TLSv1_METHOD)
+
+    class ServerTLSContext(ssl.DefaultOpenSSLContextFactory):
+        isClient = 0
+        def __init__(self, *args, **kw):
+            kw['sslmethod'] = SSL.TLSv1_METHOD
+            ssl.DefaultOpenSSLContextFactory.__init__(self, *args, **kw)
+
+class UnintelligentProtocol(basic.LineReceiver):
+    """Client: send pretext; on READY start TLS, send posttext, disconnect."""
+    pretext = [
+        "first line",
+        "last thing before tls starts",
+        "STARTTLS",
+    ]
+
+    posttext = [
+        "first thing after tls started",
+        "last thing ever",
+    ]
+
+    def connectionMade(self):
+        for l in self.pretext:
+            self.sendLine(l)
+
+    def lineReceived(self, line):
+        if line == "READY":
+            self.transport.startTLS(ClientTLSContext())
+            for l in self.posttext:
+                self.sendLine(l)
+            self.transport.loseConnection()
+
+class LineCollector(basic.LineReceiver):
+    """Server: record lines; after STARTTLS either start TLS (doTLS true)
+    or switch to raw mode to capture the bytes the client sends."""
+    def __init__(self, doTLS):
+        self.doTLS = doTLS
+
+    def connectionMade(self):
+        self.factory.rawdata = ''
+        self.factory.lines = []
+
+    def lineReceived(self, line):
+        self.factory.lines.append(line)
+        if line == 'STARTTLS':
+            self.sendLine('READY')
+            if self.doTLS:
+                ctx = ServerTLSContext(
+                    privateKeyFileName=certPath,
+                    certificateFileName=certPath,
+                )
+                self.transport.startTLS(ctx)
+            else:
+                self.setRawMode()
+
+    def rawDataReceived(self, data):
+        self.factory.rawdata += data
+        self.factory.done = 1
+
+    def connectionLost(self, reason):
+        self.factory.done = 1
+
+class TLSTestCase(unittest.TestCase):
+    def _runExchange(self, doTLS):
+        """Run one client/server exchange over loopback.
+
+        Returns the server factory, whose lines/rawdata attributes record
+        what the server received.  The listening port is always closed.
+        """
+        cf = protocol.ClientFactory()
+        cf.protocol = UnintelligentProtocol
+
+        sf = protocol.ServerFactory()
+        sf.protocol = lambda: LineCollector(doTLS)
+        sf.done = 0
+
+        port = reactor.listenTCP(0, sf)
+        try:
+            portNo = port.getHost()[2]
+            reactor.connectTCP('127.0.0.1', portNo, cf)
+
+            i = 0
+            while i < 5000 and not sf.done:
+                reactor.iterate(0.01)
+                i += 1
+        finally:
+            port.loseConnection()
+
+        self.failUnless(sf.done, "Never finished reading all lines")
+        return sf
+
+    def testTLS(self):
+        sf = self._runExchange(1)
+        self.assertEquals(
+            sf.lines,
+            UnintelligentProtocol.pretext + UnintelligentProtocol.posttext
+        )
+
+    def testUnTLS(self):
+        sf = self._runExchange(0)
+        self.assertEquals(
+            sf.lines,
+            UnintelligentProtocol.pretext
+        )
+        self.failUnless(sf.rawdata, "No encrypted bytes received")
+
+if not SSL:
+    del StolenTCPTestCase, TLSTestCase
-------------- next part --------------
A non-text attachment was scrubbed...
Name: echoclient_tls.py
Type: text/x-python
Size: 2253 bytes
Desc: not available
Url : http://twistedmatrix.com/pipermail/twisted-python/attachments/20030501/260763d3/attachment.py 
-------------- next part --------------
A non-text attachment was scrubbed...
Name: echoserv_tls.py
Type: text/x-python
Size: 2398 bytes
Desc: not available
Url : http://twistedmatrix.com/pipermail/twisted-python/attachments/20030501/260763d3/attachment-0001.py 
-------------- next part --------------
A non-text attachment was scrubbed...
Name: not available
Type: application/pgp-signature
Size: 189 bytes
Desc: not available
Url : http://twistedmatrix.com/pipermail/twisted-python/attachments/20030501/260763d3/attachment.pgp 


More information about the Twisted-Python mailing list