[Twisted-Python] TLS support for twisted (PATCH)

Skinny Puppy skin_pup-twisted at damnable.happypoo.com
Wed Feb 5 20:16:32 EST 2003


Attached is the patch for TLS support within Twisted.  It is not the best
code, but it passes all the SSL tests (there are no TLS tests yet; I'm
working on those next) and works.  The diff also adds echoclient_tls.py and
echoserv_tls.py examples, which work along the same lines as
echoclient_ssl.py/echoserv_ssl.py.

There are a few things I was unsure how best to handle:

- What error should starttls raise if OpenSSL is not installed?  exarkun
  suggested the following but I have no idea how to do it cleanly.  Any
  ideas?

exarkun : currently, connect/listenSSL aren't defined at all if
          the ssl support libs can't be implemented
exarkun : might it make sense to just duplicate that?


- The doRead_TLS/doRead_NOTLS and writeSomeData_TLS/writeSomeData_NOTLS
  pairs are just damn ugly and could be combined, but I felt it was best to
  leave them apart for now.  Is that OK, or should I merge them?

Jeremy Rossi 


-------------- next part --------------
diff -urP Twisted-1.0.2alpha4/doc/examples/echoclient_tls.py Twisted-1.0.2alpha4-halfass-branch/doc/examples/echoclient_tls.py
--- Twisted-1.0.2alpha4/doc/examples/echoclient_tls.py	Wed Dec 31 19:00:00 1969
+++ Twisted-1.0.2alpha4-halfass-branch/doc/examples/echoclient_tls.py	Wed Feb  5 19:21:06 2003
@@ -0,0 +1,65 @@
+#!/usr/bin/python
+# Twisted, the Framework of Your Internet
+# Copyright (C) 2001 Matthew W. Lefkowitz
+# 
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of version 2.1 of the GNU Lesser General Public
+# License as published by the Free Software Foundation.
+# 
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+# Lesser General Public License for more details.
+# 
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+
+from twisted.internet.protocol import ClientFactory, Protocol
+from twisted.internet import reactor, ssl
+
+# TLSv1 client context handed to transport.starttls() once the server says READY
+class myContext(ssl.ClientContextFactory):
+    isClient = 1
+    def getContext(self):
+        return ssl.SSL.Context(ssl.SSL.TLSv1_METHOD)
+
+x = myContext()
+
+class EchoClient(Protocol):
+    end="Bye-bye!"
+    def connectionMade(self):
+        self.transport.write("I am sending this in the clear\n")
+        self.transport.write("And why should I not?\n")
+        self.transport.write("STARTTLS;\n")
+
+    def dataReceived(self, data):
+        for i in data.split("\n"):
+            try:
+                command, other = i.split(";", 1)
+            except ValueError:
+                # no ";" on this line, so it is plain echoed text
+                command = ""
+                other = i
+            if command==self.end:
+                self.transport.loseConnection()
+            elif command=="READY":
+                self.transport.starttls(x)
+                self.transport.write("Spooks cannot see me now.\n")
+            else:
+                print i
+
+class EchoClientFactory(ClientFactory):
+    protocol = EchoClient
+
+    def clientConnectionFailed(self, connector, reason):
+        print 'connection failed:', reason.getErrorMessage()
+        reactor.stop()
+
+    def clientConnectionLost(self, connector, reason):
+        print 'connection lost:', reason.getErrorMessage()
+        reactor.stop()
+
+factory = EchoClientFactory()
+reactor.connectTCP('localhost', 8000, factory)
+reactor.run()
diff -urP Twisted-1.0.2alpha4/doc/examples/echoserv_tls.py Twisted-1.0.2alpha4-halfass-branch/doc/examples/echoserv_tls.py
--- Twisted-1.0.2alpha4/doc/examples/echoserv_tls.py	Wed Dec 31 19:00:00 1969
+++ Twisted-1.0.2alpha4-halfass-branch/doc/examples/echoserv_tls.py	Wed Feb  5 19:20:55 2003
@@ -0,0 +1,66 @@
+
+# Twisted, the Framework of Your Internet
+# Copyright (C) 2001 Matthew W. Lefkowitz
+# 
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of version 2.1 of the GNU Lesser General Public
+# License as published by the Free Software Foundation.
+# 
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+# Lesser General Public License for more details.
+# 
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+
+from twisted.internet.protocol import Factory
+from twisted.protocols.basic import LineReceiver
+from twisted.internet import ssl
+
+### Protocol Implementation
+
+# TLSv1 server context used once the client asks for STARTTLS
+x = ssl.DefaultOpenSSLContextFactory(privateKeyFileName="server.pem",
+                                 certificateFileName="server.pem",
+                                 sslmethod=ssl.SSL.TLSv1_METHOD)
+
+
+class Echo(LineReceiver):
+    """Echo each line back; switch to TLS when a STARTTLS; line arrives."""
+
+    # the example client ends its lines with a bare "\n", not "\r\n"
+    delimiter = "\n"
+
+    def lineReceived(self, line):
+        # the client may send several lines in one TCP segment, so the
+        # command is matched per line rather than per dataReceived() chunk
+        print line
+        command = line.split(";", 1)[0]
+        if command == "STARTTLS":
+            print "starting TLS"
+            self.transport.write("READY;ajshdakjsd\n")
+            self.transport.starttls(x)
+        else:
+            self.sendLine(line)
+
+
+### Persistent Application Builder
+
+# This builds a .tap file
+class EchoFactory(Factory):
+    protocol = Echo
+
+
+if __name__ == '__main__':
+    # Since this is persistent, it's important to get the module naming right
+    # (If we just used Echo, then it would be __main__.Echo when it attempted
+    # to unpickle)
+    import echoserv_tls
+    from twisted.internet.app import Application
+    factory = echoserv_tls.EchoFactory()
+    factory.protocol = echoserv_tls.Echo
+    app = Application("echo-tls")
+    app.listenTCP(8000,factory)
+    app.save("start")
diff -urP Twisted-1.0.2alpha4/twisted/internet/ssl.py Twisted-1.0.2alpha4-halfass-branch/twisted/internet/ssl.py
--- Twisted-1.0.2alpha4/twisted/internet/ssl.py	Wed Jan  8 09:18:53 2003
+++ Twisted-1.0.2alpha4-halfass-branch/twisted/internet/ssl.py	Wed Feb  5 19:20:06 2003
@@ -40,11 +40,11 @@
 import socket
 
 # sibling imports
-import tcp, main, interfaces
+import main, interfaces, tcp
 
 # Twisted imports
 from twisted.python import log
 
 
 class ContextFactory:
     """A factory for SSL context objects, for server SSL connections."""
@@ -95,101 +95,7 @@
         return SSL.Context(SSL.SSLv3_METHOD)
 
 
-class Connection(tcp.Connection):
-    """I am an SSL connection.
-    """
-
-    __implements__ = tcp.Connection.__implements__, interfaces.ISSLTransport
-    
-    writeBlockedOnRead = 0
-    readBlockedOnWrite= 0
-    sslShutdown = 0
-    
-    def getPeerCertificate(self):
-        """Return the certificate for the peer."""
-        return self.socket.get_peer_certificate()
-
-    def _postLoseConnection(self):
-        """Gets called after loseConnection(), after buffered data is sent.
-
-        We close the SSL transport layer, and if the other side hasn't
-        closed it yet we start reading, waiting for a ZeroReturnError
-        which will indicate the SSL shutdown has completed.
-        """
-        try:
-            done = self.socket.shutdown()
-            self.sslShutdown = 1
-        except SSL.Error:
-            return main.CONNECTION_LOST
-        if done:
-            return main.CONNECTION_DONE
-        else:
-            # we wait for other side to close SSL connection -
-            # this will be signaled by SSL.ZeroReturnError when reading
-            # from the socket
-            self.stopWriting()
-            self.startReading()
-            return None # don't close socket just yet
-    
-    def doRead(self):
-        """See tcp.Connection.doRead for details.
-        """
-        if self.writeBlockedOnRead:
-            self.writeBlockedOnRead = 0
-            return self.doWrite()
-        try:
-            return tcp.Connection.doRead(self)
-        except SSL.ZeroReturnError:
-            # close SSL layer, since other side has done so, if we haven't
-            if not self.sslShutdown:
-                try:
-                    self.socket.shutdown()
-                    self.sslShutdown = 1
-                except SSL.Error:
-                    pass
-            return main.CONNECTION_DONE
-        except SSL.WantReadError:
-            return
-        except SSL.WantWriteError:
-            self.readBlockedOnWrite = 1
-            self.startWriting()
-            return
-        except SSL.Error:
-            return main.CONNECTION_LOST
-
-    def doWrite(self):
-        if self.readBlockedOnWrite:
-            self.readBlockedOnWrite = 0
-            if not self.unsent: self.stopWriting()
-            return self.doRead()
-        return tcp.Connection.doWrite(self)
-    
-    def writeSomeData(self, data):
-        """See tcp.Connection.writeSomeData for details.
-        """
-        if not data:
-            return 0
-
-        try:
-            return tcp.Connection.writeSomeData(self, data)
-        except SSL.WantWriteError:
-            return 0
-        except SSL.WantReadError:
-            self.writeBlockedOnRead = 1
-            return 0
-        except SSL.Error:
-            return main.CONNECTION_LOST
-
-    def _closeSocket(self):
-        """Called to close our socket."""
-        try:
-            self.socket.sock_shutdown(2)
-        except socket.error:
-            pass
-
-
-
-class Client(Connection, tcp.TCPClient):
+class Client(tcp.TCPClient):
     """I am an SSL client.
     """
     
@@ -197,11 +103,15 @@
         self.ctxFactory = ctxFactory
         tcp.TCPClient.__init__(self, host, port, bindAddress, connector, reactor)
     
+    def _finishInit(self, whenDone, skt, error, reactor):
+        tcp.TCPClient._finishInit(self, whenDone, skt, error, reactor)
+        self.starttls(self.ctxFactory)
+
     def createInternetSocket(self):
         """(internal) create an SSL socket
         """
         sock = tcp.TCPClient.createInternetSocket(self)
-        return SSL.Connection(self.ctxFactory.getContext(), sock)
+        return sock
 
     def getHost(self):
         """Returns a tuple of ('SSL', hostname, port).
@@ -220,13 +130,13 @@
 
 
 
-class Server(Connection, tcp.Server):
+class Server(tcp.Server):
     """I am an SSL server.
     """
     
     def __init__(*args, **kwargs):
         # we need those so we don't use ssl.Connection's __init__
         apply(tcp.Server.__init__, args, kwargs)
 
     def getHost(self):
         """Returns a tuple of ('SSL', hostname, port).
@@ -283,6 +193,7 @@
             s = self.sessionno
             self.sessionno = s+1
             transport = self.transport(skt, protocol, addr, self, s)
+            transport.TLS = 1
             protocol.makeConnection(transport)
         except:
             log.deferr()
diff -urP Twisted-1.0.2alpha4/twisted/internet/tcp.py Twisted-1.0.2alpha4-halfass-branch/twisted/internet/tcp.py
--- Twisted-1.0.2alpha4/twisted/internet/tcp.py	Wed Jan  1 09:32:27 2003
+++ Twisted-1.0.2alpha4-halfass-branch/twisted/internet/tcp.py	Wed Feb  5 19:20:02 2003
@@ -34,6 +34,10 @@
 import select
 import operator
 import warnings
+try:
+    from OpenSSL import SSL
+except ImportError:
+    SSL = None

 if os.name == 'nt':
     EINVAL      = 10022
@@ -73,6 +77,10 @@
     This is an abstract superclass of all objects which represent a TCP/IP
     connection based socket.
     """
+    writeBlockedOnRead = 0
+    readBlockedOnWrite= 0
+    sslShutdown = 0
+    TLS = 0

     __implements__ = abstract.FileDescriptor.__implements__, interfaces.ITCPTransport

@@ -83,7 +91,52 @@
         self.fileno = skt.fileno
         self.protocol = protocol

+    def starttls(self, ctx):
+        """Switch to TLS by wrapping the socket with ctx.getContext()."""
+        if SSL is None:
+            raise RuntimeError("starttls() requires pyOpenSSL")
+        self._oldsocket = self.socket
+        self.socket = SSL.Connection(ctx.getContext(), self.socket)
+        self.fileno = self.socket.fileno
+        self.TLS = 1
+
+    def stoptls(self):
+        """Go back to the plain socket saved by starttls()."""
+        self.socket = self._oldsocket
+        self.fileno = self.socket.fileno
+        self.TLS = 0
+
     def doRead(self):
+        if self.TLS:
+            return self.doRead_TLS()
+        else:
+            return self.doRead_NOTLS()
+
+    def doRead_TLS(self):
+        if self.writeBlockedOnRead:
+            self.writeBlockedOnRead = 0
+            return self.doWrite()
+        try:
+            return self.doRead_NOTLS()
+        except SSL.ZeroReturnError:
+            # close SSL layer, since other side has done so, if we haven't
+            if not self.sslShutdown:
+                try:
+                    self.socket.shutdown()
+                    self.sslShutdown = 1
+                except SSL.Error:
+                    pass
+            return main.CONNECTION_DONE
+        except SSL.WantReadError:
+            return
+        except SSL.WantWriteError:
+            self.readBlockedOnWrite = 1
+            self.startWriting()
+            return
+        except SSL.Error:
+            return main.CONNECTION_LOST
+
+    def doRead_NOTLS(self):
         """Calls self.protocol.dataReceived with all available data.

         This reads up to self.bufferSize bytes of data from its socket, then
@@ -102,7 +155,36 @@
             return main.CONNECTION_LOST
         return self.protocol.dataReceived(data)

+    
+    def doWrite(self):
+        if self.TLS:
+            if self.readBlockedOnWrite:
+                self.readBlockedOnWrite = 0
+                if not self.unsent: self.stopWriting()
+                return self.doRead()
+        return abstract.FileDescriptor.doWrite(self)
+
     def writeSomeData(self, data):
+        if self.TLS:
+            return self.writeSomeData_TLS(data)
+        else:
+            return self.writeSomeData_NOTLS(data)
+
+
+    def writeSomeData_TLS(self, data):
+        if not data:
+            return 0
+        try:
+            return self.writeSomeData_NOTLS(data)
+        except SSL.WantWriteError:
+            return 0
+        except SSL.WantReadError:
+            self.writeBlockedOnRead = 1
+            return 0
+        except SSL.Error:
+            return main.CONNECTION_LOST
+
+    def writeSomeData_NOTLS(self, data):
         """Connection.writeSomeData(data) -> #of bytes written | CONNECTION_LOST
         This writes as much data as possible to the socket and returns either
         the number of bytes read (which is positive) or a connection error code
@@ -123,7 +205,10 @@
         # was inherited by a subprocess. And we really do want to close the
         # connection. So we use shutdown() instead.
         try:
-            self.socket.shutdown(2)
+            if self.TLS:
+                self.socket.sock_shutdown(2)
+            else:
+                self.socket.shutdown(2)
         except socket.error:
             pass

@@ -162,11 +247,44 @@
     def setTcpNoDelay(self, enabled):
         self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, enabled)

+    def getPeerCertificate(self):
+        """Return the certificate for the peer."""
+        return self.socket.get_peer_certificate()
+
+    def _postLoseConnection(self):
+        """Gets called after loseConnection(), after buffered data is sent.
+
+        We close the SSL transport layer, and if the other side hasn't
+        closed it yet we start reading, waiting for a ZeroReturnError
+        which will indicate the SSL shutdown has completed.
+        """
+        if self.TLS:
+            try:
+                done = self.socket.shutdown()
+                self.sslShutdown = 1
+            except SSL.Error:
+                return main.CONNECTION_LOST
+            if done:
+                return main.CONNECTION_DONE
+            else:
+                # we wait for other side to close SSL connection -
+                # this will be signaled by SSL.ZeroReturnError when reading
+                # from the socket
+                self.stopWriting()
+                self.startReading()
+                return None # don't close socket just yet
+        else:
+            return main.CONNECTION_DONE

 class BaseClient(Connection):
     """A base class for client TCP (and similiar) sockets.
     """

+    def starttls(self, ctx):
+        holder = Connection.starttls(self, ctx)
+        self.socket.set_connect_state()
+        return holder
+
     def _finishInit(self, whenDone, skt, error, reactor):
         """Called by base classes to continue to next stage of initialization."""
         if whenDone:
@@ -360,6 +478,12 @@
         self.startReading()
         self.connected = 1

+
+    def starttls(self, ctx):
+        holder = Connection.starttls(self, ctx)
+        self.socket.set_accept_state()
+        return holder
+
     def __repr__(self):
         """A string representation of this connection.
         """


More information about the Twisted-Python mailing list