[Twisted-Python] Updated TLS patch

Skinny Puppy skin_pup-twisted at damnable.happypoo.com
Fri May 2 23:02:26 EDT 2003


Glyph Lefkowitz [glyph at twistedmatrix.com] wrote:
> -----BEGIN PGP SIGNED MESSAGE-----
> Hash: SHA1
> 
> Ugly as it is, this looks like the right answer to me...
> 
> On Friday, May 2, 2003, at 07:30 AM, Skinny Puppy wrote:
> 
> >The branch/function call can be avoided by replacing the 
> >doRead/doWrite/etc
> >methods in startTLS.  While this is still not very perty ;)
> -----BEGIN PGP SIGNATURE-----
> Version: GnuPG v1.2.1 (Darwin)
> 
> iD8DBQE+svG0vVGR4uSOE2wRAqo+AJ40/0hBnDnEh1267vYe7hAJV0TEUwCeNklv
> Qya3OyfpjxoexyNSb3iLPqc=
> =24qI
> -----END PGP SIGNATURE-----

OK, done, although I still don't like it.  I have not run any real-world
tests yet, but I have used echoserv_tls.py/echoclient_tls.py and watched the
traffic with tcpdump to verify the encryption, and of course I have run the
unit tests.

Jeremy 
-------------- next part --------------
? doc/examples/echoclient_tls.py
? doc/examples/echoserv_tls.py
Index: twisted/internet/ssl.py
===================================================================
RCS file: /cvs/Twisted/twisted/internet/ssl.py,v
retrieving revision 1.40
diff -u -r1.40 ssl.py
--- twisted/internet/ssl.py	2 Apr 2003 04:11:32 -0000	1.40
+++ twisted/internet/ssl.py	3 May 2003 06:28:42 -0000
@@ -95,116 +95,13 @@
         return SSL.Context(SSL.SSLv3_METHOD)
 
 
-class Connection(tcp.Connection):
-    """I am an SSL connection.
-    """
-
-    __implements__ = tcp.Connection.__implements__, interfaces.ISSLTransport
-    
-    writeBlockedOnRead = 0
-    readBlockedOnWrite= 0
-    sslShutdown = 0
-    
-    def getPeerCertificate(self):
-        """Return the certificate for the peer."""
-        return self.socket.get_peer_certificate()
-
-    def _postLoseConnection(self):
-        """Gets called after loseConnection(), after buffered data is sent.
-
-        We close the SSL transport layer, and if the other side hasn't
-        closed it yet we start reading, waiting for a ZeroReturnError
-        which will indicate the SSL shutdown has completed.
-        """
-        try:
-            done = self.socket.shutdown()
-            self.sslShutdown = 1
-        except SSL.Error:
-            return main.CONNECTION_LOST
-        if done:
-            return main.CONNECTION_DONE
-        else:
-            # we wait for other side to close SSL connection -
-            # this will be signaled by SSL.ZeroReturnError when reading
-            # from the socket
-            self.stopWriting()
-            self.startReading()
-            return None # don't close socket just yet
-    
-    def doRead(self):
-        """See tcp.Connection.doRead for details.
-        """
-        if self.writeBlockedOnRead:
-            self.writeBlockedOnRead = 0
-            return self.doWrite()
-        try:
-            return tcp.Connection.doRead(self)
-        except SSL.ZeroReturnError:
-            # close SSL layer, since other side has done so, if we haven't
-            if not self.sslShutdown:
-                try:
-                    self.socket.shutdown()
-                    self.sslShutdown = 1
-                except SSL.Error:
-                    pass
-            return main.CONNECTION_DONE
-        except SSL.WantReadError:
-            return
-        except SSL.WantWriteError:
-            self.readBlockedOnWrite = 1
-            self.startWriting()
-            return
-        except SSL.Error:
-            return main.CONNECTION_LOST
-
-    def doWrite(self):
-        if self.readBlockedOnWrite:
-            self.readBlockedOnWrite = 0
-            if not self.dataBuffer: self.stopWriting()
-            return self.doRead()
-        return tcp.Connection.doWrite(self)
-    
-    def writeSomeData(self, data):
-        """See tcp.Connection.writeSomeData for details.
-        """
-        if not data:
-            return 0
-
-        try:
-            return tcp.Connection.writeSomeData(self, data)
-        except SSL.WantWriteError:
-            return 0
-        except SSL.WantReadError:
-            self.writeBlockedOnRead = 1
-            return 0
-        except SSL.Error:
-            return main.CONNECTION_LOST
-
-    def _closeSocket(self):
-        """Called to close our socket."""
-        try:
-            self.socket.sock_shutdown(2)
-        except socket.error:
-            try:
-                self.socket.close()
-            except socket.error:
-                log.deferr()
-
-
-
-class Client(Connection, tcp.Client):
+class Client(tcp.Client):
     """I am an SSL client."""
     def __init__(self, host, port, bindAddress, ctxFactory, connector, reactor=None):
         # tcp.Client.__init__ depends on self.ctxFactory being set
         self.ctxFactory = ctxFactory
         tcp.Client.__init__(self, host, port, bindAddress, connector, reactor)
 
-    def createInternetSocket(self):
-        """(internal) create an SSL socket
-        """
-        sock = tcp.Client.createInternetSocket(self)
-        return SSL.Connection(self.ctxFactory.getContext(), sock)
-
     def getHost(self):
         """Returns a tuple of ('SSL', hostname, port).
 
@@ -219,16 +116,14 @@
         """
         return ('SSL',)+self.addr
 
+    def _finishInit(self, whenDone, skt, error, reactor):
+        tcp.Client._finishInit(self, whenDone, skt, error, reactor)  # NOTE(review): startTLS below also runs when error is set - confirm that is harmless
+        self.startTLS(self.ctxFactory)  # unlike plain TCP, an SSL client switches to TLS right after TCP setup
 
 
-class Server(Connection, tcp.Server):
+class Server(tcp.Server):
     """I am an SSL server.
     """
-    
-    def __init__(*args, **kw):
-        # We don't want Connection's __init__
-        tcp.Server.__init__(*args, **kw)
-    
     def getHost(self):
         """Returns a tuple of ('SSL', hostname, port).
 
@@ -257,33 +152,12 @@
         """
         sock = tcp.Port.createInternetSocket(self)
         return SSL.Connection(self.ctxFactory.getContext(), sock)
-    
-    def doRead(self):
-        """Called when my socket is ready for reading.
 
-        This accepts a connection and calls self.protocol() to handle the
-        wire-level protocol.
-        """
-        try:
-            try:
-                skt, addr = self.socket.accept()
-            except socket.error, e:
-                if e.args[0] == tcp.EWOULDBLOCK:
-                    return
-                raise
-            except SSL.Error:
-                log.deferr()
-                return
-            protocol = self.factory.buildProtocol(addr)
-            if protocol is None:
-                skt.close()
-                return
-            s = self.sessionno
-            self.sessionno = s+1
-            transport = self.transport(skt, protocol, addr, self, s)
-            protocol.makeConnection(transport)
-        except:
-            log.deferr()
+    def _preMakeConnection(self, transport):
+        # Don't call startTLS here: the accepted socket is already an SSL.Connection
+        # (see createInternetSocket above), so only swap in the TLS I/O methods.
+        transport._startTLS()
+        return tcp.Port._preMakeConnection(self, transport)
 
 
 class Connector(base.BaseConnector):
Index: twisted/internet/tcp.py
===================================================================
RCS file: /cvs/Twisted/twisted/internet/tcp.py,v
retrieving revision 1.118
diff -u -r1.118 tcp.py
--- twisted/internet/tcp.py	2 May 2003 04:31:14 -0000	1.118
+++ twisted/internet/tcp.py	3 May 2003 06:28:50 -0000
@@ -39,6 +39,11 @@
 except ImportError:
     fcntl = None
 
+try:
+    from OpenSSL import SSL
+except ImportError:
+    SSL = None
+
 if os.name == 'nt':
     # we hardcode these since windows actually wants e.g.
     # WSAEALREADY rather than EALREADY. Possibly we should
@@ -88,14 +93,36 @@
 
     __implements__ = abstract.FileDescriptor.__implements__, interfaces.ITCPTransport
 
+    # TLS state, defined even without OpenSSL so self.TLS can always be tested
+    writeBlockedOnRead = 0
+    readBlockedOnWrite = 0
+    sslShutdown = 0
+    TLS = 0
+
     def __init__(self, skt, protocol, reactor=None):
         abstract.FileDescriptor.__init__(self, reactor=reactor)
         self.socket = skt
         self.socket.setblocking(0)
         self.fileno = skt.fileno
         self.protocol = protocol
+        
+    def startTLS(self, ctx):  # ctx.getContext() supplies the context passed to SSL.Connection
+        if not SSL:
+            raise RuntimeError, "No SSL support available"
+        assert not self.TLS
 
-    def doRead(self):
+        self.socket = SSL.Connection(ctx.getContext(), self.socket)
+        self.fileno = self.socket.fileno
+        self._startTLS()  # only after the wrap succeeded, so a failure leaves plain TCP intact
+
+    def _startTLS(self):  # install the _TLS_* I/O methods; self.socket must already be SSL-wrapped
+        self.TLS = 1
+        self.doRead = self._TLS_doRead
+        self.writeSomeData = self._TLS_writeSomeData
+        self.doWrite = self._TLS_doWrite
+        self._closeSocket = self._TLS_closeSocket
+
+    def _NOTLS_doRead(self):
         """Calls self.protocol.dataReceived with all available data.
 
         This reads up to self.bufferSize bytes of data from its socket, then
@@ -114,7 +141,42 @@
             return main.CONNECTION_LOST
         return self.protocol.dataReceived(data)
 
-    def writeSomeData(self, data):
+    doRead = _NOTLS_doRead
+    
+    def _TLS_doRead(self):
+        if self.writeBlockedOnRead:  # a write hit SSL.WantReadError; retry it now that we are readable
+            self.writeBlockedOnRead = 0
+            return self.doWrite()
+        try:
+            return self._NOTLS_doRead()
+        except SSL.ZeroReturnError:
+            # close SSL layer, since other side has done so, if we haven't
+            if not self.sslShutdown:
+                try:
+                    self.socket.shutdown()
+                    self.sslShutdown = 1
+                except SSL.Error:
+                    pass
+            return main.CONNECTION_DONE
+        except SSL.WantReadError:
+            return  # nothing to deliver yet; wait for the next readable event
+        except SSL.WantWriteError:
+            self.readBlockedOnWrite = 1  # _TLS_doWrite retries the read once writable
+            self.startWriting()
+            return
+        except SSL.Error:
+            return main.CONNECTION_LOST
+
+    def _TLS_doWrite(self):
+        if self.readBlockedOnWrite:  # a read hit SSL.WantWriteError; retry it now that we are writable
+            self.readBlockedOnWrite = 0
+            # XXX - touches FileDescriptor's internal dataBuffer: stop writing if nothing is queued
+            if not self.dataBuffer:
+                self.stopWriting()
+            return self.doRead()
+        return abstract.FileDescriptor.doWrite(self)
+
+    def _NOTLS_writeSomeData(self, data):
         """Connection.writeSomeData(data) -> #of bytes written | CONNECTION_LOST
         This writes as much data as possible to the socket and returns either
         the number of bytes read (which is positive) or a connection error code
@@ -128,7 +190,24 @@
             else:
                 return main.CONNECTION_LOST
 
-    def _closeSocket(self):
+    writeSomeData = _NOTLS_writeSomeData
+
+    def _TLS_writeSomeData(self, data):
+        """TLS version of writeSomeData: SSL want-read/want-write mean 0 bytes written."""
+        if not data:
+            return 0
+        try:
+            return self._NOTLS_writeSomeData(data)
+        except SSL.WantWriteError:
+            return 0
+        except SSL.WantReadError:
+            # _TLS_doRead retries the write once the socket is readable.
+            self.writeBlockedOnRead = 1
+            return 0
+        except SSL.Error:
+            return main.CONNECTION_LOST
+
+    def _NOTLS_closeSocket(self):
         """Called to close our socket."""
         # This used to close() the socket, but that doesn't *really* close if
         # there's another reference to it in the TCP/IP stack, e.g. if it was
@@ -139,6 +215,17 @@
         except socket.error:
             pass
 
+    _closeSocket = _NOTLS_closeSocket
+
+    def _TLS_closeSocket(self):  # TLS version of _closeSocket, as in the pre-patch ssl.Connection
+        try:
+            self.socket.sock_shutdown(2)  # shut down the underlying socket in both directions
+        except socket.error:
+            try:
+                self.socket.close()
+            except socket.error:
+                log.deferr()
+
     def connectionLost(self, reason):
         """See abstract.FileDescriptor.connectionLost().
         """
@@ -173,6 +260,33 @@
 
     def setTcpNoDelay(self, enabled):
         self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, enabled)
+    
+    def _postLoseConnection(self):
+        """Gets called after loseConnection(), after buffered data is sent.
+
+        Non-TLS connections use the FileDescriptor default.  For TLS we close
+        the SSL layer and, if the other side hasn't closed it yet, keep
+        reading until SSL.ZeroReturnError signals the shutdown has completed.
+        """
+        if not self.TLS:
+            return abstract.FileDescriptor._postLoseConnection(self)
+
+        try:
+            done = self.socket.shutdown()  # close our side of the SSL layer
+            self.sslShutdown = 1
+        except SSL.Error:
+            return main.CONNECTION_LOST
+        if done:
+            return main.CONNECTION_DONE
+        else:
+            # we wait for other side to close SSL connection -
+            # this will be signaled by SSL.ZeroReturnError when reading
+            # from the socket
+            self.stopWriting()
+            self.startReading()
+            
+            # don't close socket just yet
+            return None
 
 
 class BaseClient(Connection):
@@ -191,6 +305,11 @@
         else:
             reactor.callLater(0, self.failIfNotConnected, error)
 
+    def startTLS(self, ctx):
+        holder = Connection.startTLS(self, ctx)  # pass through Connection.startTLS's return value
+        self.socket.set_connect_state()  # we initiated the connection: act as the TLS client
+        return holder
+
     def stopConnecting(self):
         """Stop attempt to connect."""
         self.failIfNotConnected(error.UserError())
@@ -360,6 +479,11 @@
         """
         return self.repstr
 
+    def startTLS(self, ctx):
+        holder = Connection.startTLS(self, ctx)  # pass through Connection.startTLS's return value
+        self.socket.set_accept_state()  # accepted connection: act as the TLS server
+        return holder
+
     def getHost(self):
         """Returns a tuple of ('INET', hostname, port).
 
@@ -458,6 +582,7 @@
                     elif e.args[0] == EPERM:
                         continue
                     raise
+                
                 protocol = self.factory.buildProtocol(addr)
                 if protocol is None:
                     skt.close()
@@ -465,11 +590,22 @@
                 s = self.sessionno
                 self.sessionno = s+1
                 transport = self.transport(skt, protocol, addr, self, s)
+                transport = self._preMakeConnection(transport)
                 protocol.makeConnection(transport)
             else:
                 self.numberAccepts = self.numberAccepts+20
         except:
+            # Note that in TLS mode, this will possibly catch SSL.Errors
+            # raised by self.socket.accept()
+            #
+            # There is no "except SSL.Error:" above because SSL may be
+            # None if there is no SSL support.  In any case, all the
+            # "except SSL.Error:" suite would probably do is log.deferr()
+            # and return, so handling it here works just as well.
             log.deferr()
+    
+    def _preMakeConnection(self, transport):
+        return transport  # hook called before makeConnection; subclasses (e.g. the SSL port) override it
 
     def loseConnection(self, connDone=failure.Failure(main.CONNECTION_DONE)):
         """Stop accepting connections on this port.
Index: twisted/test/test_ssl.py
===================================================================
RCS file: /cvs/Twisted/twisted/test/test_ssl.py,v
retrieving revision 1.9
diff -u -r1.9 test_ssl.py
--- twisted/test/test_ssl.py	3 May 2003 02:03:54 -0000	1.9
+++ twisted/test/test_ssl.py	3 May 2003 06:28:55 -0000
@@ -17,19 +17,23 @@
 from __future__ import nested_scopes
 from twisted.trial import unittest
 from twisted.internet import protocol, reactor
+from twisted.protocols import basic
+
 try:
-    import OpenSSL
+    from OpenSSL import SSL
     from twisted.internet import ssl
 except ImportError:
-    OpenSSL = None
+    SSL = None
+
 import os
 import test_tcp
 
 
+certPath = os.path.join(os.path.split(test_tcp.__file__)[0], "server.pem")
+
 class StolenTCPTestCase(test_tcp.ProperlyCloseFilesTestCase, test_tcp.WriteDataTestCase):
     
     def setUp(self):
-        certPath = os.path.join(os.path.split(test_tcp.__file__)[0], "server.pem")
         f = protocol.ServerFactory()
         f.protocol = protocol.Protocol
         self.listener = reactor.listenSSL(
@@ -49,5 +53,120 @@
 
         self.totalConnections = 0
 
-if not OpenSSL:
-    del StolenTCPTestCase
+if SSL:
+    # twisted.internet.ssl is only importable when pyOpenSSL is installed,
+    # so the context factories that subclass it are defined conditionally.
+    class ClientTLSContext(ssl.ClientContextFactory):
+        isClient = 1
+        def getContext(self):
+            return SSL.Context(SSL.TLSv1_METHOD)
+
+    class ServerTLSContext(ssl.DefaultOpenSSLContextFactory):
+        isClient = 0
+        def __init__(self, *args, **kw):
+            kw['sslmethod'] = SSL.TLSv1_METHOD
+            ssl.DefaultOpenSSLContextFactory.__init__(self, *args, **kw)
+
+class UnintelligentProtocol(basic.LineReceiver):
+    """Client side: send pretext, wait for READY, start TLS, send posttext."""
+    pretext = [
+        "first line",
+        "last thing before tls starts",
+        "STARTTLS",
+    ]
+
+    posttext = [
+        "first thing after tls started",
+        "last thing ever",
+    ]
+
+    def connectionMade(self):
+        for l in self.pretext:
+            self.sendLine(l)
+
+    def lineReceived(self, line):
+        if line == "READY":
+            self.transport.startTLS(ClientTLSContext())
+            for l in self.posttext:
+                self.sendLine(l)
+            self.transport.loseConnection()
+
+class LineCollector(basic.LineReceiver):
+    """Server side: record received lines on the factory.  On STARTTLS,
+    reply READY, then either start TLS (doTLS true) or switch to raw mode
+    so the client's encrypted bytes are collected as raw data (doTLS false)."""
+    def __init__(self, doTLS):
+        self.doTLS = doTLS
+
+    def connectionMade(self):
+        self.factory.rawdata = ''
+        self.factory.lines = []
+
+    def lineReceived(self, line):
+        self.factory.lines.append(line)
+        if line == 'STARTTLS':
+            self.sendLine('READY')
+            if self.doTLS:
+                ctx = ServerTLSContext(
+                    privateKeyFileName=certPath,
+                    certificateFileName=certPath,
+                )
+                self.transport.startTLS(ctx)
+            else:
+                self.setRawMode()
+
+    def rawDataReceived(self, data):
+        self.factory.rawdata += data
+        self.factory.done = 1
+
+    def connectionLost(self, reason):
+        self.factory.done = 1
+
+class TLSTestCase(unittest.TestCase):
+    def _runLineTest(self, doTLS):
+        """Connect an UnintelligentProtocol to a LineCollector(doTLS) over
+        loopback TCP, iterate until the server is done, return its factory."""
+        cf = protocol.ClientFactory()
+        cf.protocol = UnintelligentProtocol
+
+        sf = protocol.ServerFactory()
+        sf.protocol = lambda: LineCollector(doTLS)
+        sf.done = 0
+
+        port = reactor.listenTCP(0, sf)
+        try:
+            portNo = port.getHost()[2]
+            # Use the loopback address: 0.0.0.0 is not a portable
+            # connect address (it fails on Windows, for example).
+            reactor.connectTCP('127.0.0.1', portNo, cf)
+
+            i = 0
+            while i < 5000 and not sf.done:
+                reactor.iterate(0.01)
+                i += 1
+        finally:
+            # Don't leak the listening port into later tests.
+            port.loseConnection()
+
+        self.failUnless(sf.done, "Never finished reading all lines")
+        return sf
+
+    def testTLS(self):
+        sf = self._runLineTest(1)
+        self.assertEquals(
+            sf.lines,
+            UnintelligentProtocol.pretext + UnintelligentProtocol.posttext
+        )
+
+    def testUnTLS(self):
+        sf = self._runLineTest(0)
+        self.assertEquals(
+            sf.lines,
+            UnintelligentProtocol.pretext
+        )
+        self.failUnless(sf.rawdata, "No encrypted bytes received")
+
+if not SSL:
+    # Without pyOpenSSL none of the SSL/TLS tests can run; remove them
+    # instead of clearing the whole module namespace.
+    del StolenTCPTestCase, TLSTestCase


More information about the Twisted-Python mailing list