[Twisted-Python] TLS support for twisted (PATCH)

Skinny Puppy skin_pup-twisted at damnable.happypoo.com
Wed Feb 5 23:04:44 MST 2003


Itamar Shtull-Trauring [twisted at itamarst.org] wrote:
> Skinny Puppy <skin_pup-twisted at damnable.happypoo.com> wrote:
> 
> > Attached is the patch for TLS support within twisted.  It is not the
> > best code,  but it passes all the SSL test (None working on that next)
> > and works.  The diff also adds an echoclient_tls.py/echoserv_tls.py,
> > that work along with the echoclient_ssl.py/echoserv_ssl.py
> 
> Neato! I don't have time to look at it right now, but it *will* be
> looked at sooner or later. If I don't check this in within a week bug me
> personally.

Will be more than happy to, but in the meantime I got bored and
checked out the CVS version of Twisted and worked in TLS.  Patch attached 
of course.  Boredom will be my downfall, so please don't think I am being
pushy or anything.  :)

Jeremy Rossi 
-------------- next part --------------
? TLS.diff
? TLS.patch
? doc/examples/echoclient_tls.py
? doc/examples/echoserv_tls.py
Index: twisted/internet/ssl.py
===================================================================
RCS file: /cvs/Twisted/twisted/internet/ssl.py,v
retrieving revision 1.37
diff -u -r1.37 ssl.py
--- twisted/internet/ssl.py	3 Feb 2003 18:55:59 -0000	1.37
+++ twisted/internet/ssl.py	6 Feb 2003 05:50:16 -0000
@@ -95,112 +95,18 @@
         return SSL.Context(SSL.SSLv3_METHOD)
 
 
-class Connection(tcp.Connection):
-    """I am an SSL connection.
-    """
-
-    __implements__ = tcp.Connection.__implements__, interfaces.ISSLTransport
-    
-    writeBlockedOnRead = 0
-    readBlockedOnWrite= 0
-    sslShutdown = 0
-    
-    def getPeerCertificate(self):
-        """Return the certificate for the peer."""
-        return self.socket.get_peer_certificate()
-
-    def _postLoseConnection(self):
-        """Gets called after loseConnection(), after buffered data is sent.
-
-        We close the SSL transport layer, and if the other side hasn't
-        closed it yet we start reading, waiting for a ZeroReturnError
-        which will indicate the SSL shutdown has completed.
-        """
-        try:
-            done = self.socket.shutdown()
-            self.sslShutdown = 1
-        except SSL.Error:
-            return main.CONNECTION_LOST
-        if done:
-            return main.CONNECTION_DONE
-        else:
-            # we wait for other side to close SSL connection -
-            # this will be signaled by SSL.ZeroReturnError when reading
-            # from the socket
-            self.stopWriting()
-            self.startReading()
-            return None # don't close socket just yet
-    
-    def doRead(self):
-        """See tcp.Connection.doRead for details.
-        """
-        if self.writeBlockedOnRead:
-            self.writeBlockedOnRead = 0
-            return self.doWrite()
-        try:
-            return tcp.Connection.doRead(self)
-        except SSL.ZeroReturnError:
-            # close SSL layer, since other side has done so, if we haven't
-            if not self.sslShutdown:
-                try:
-                    self.socket.shutdown()
-                    self.sslShutdown = 1
-                except SSL.Error:
-                    pass
-            return main.CONNECTION_DONE
-        except SSL.WantReadError:
-            return
-        except SSL.WantWriteError:
-            self.readBlockedOnWrite = 1
-            self.startWriting()
-            return
-        except SSL.Error:
-            return main.CONNECTION_LOST
-
-    def doWrite(self):
-        if self.readBlockedOnWrite:
-            self.readBlockedOnWrite = 0
-            if not self.unsent: self.stopWriting()
-            return self.doRead()
-        return tcp.Connection.doWrite(self)
-    
-    def writeSomeData(self, data):
-        """See tcp.Connection.writeSomeData for details.
-        """
-        if not data:
-            return 0
 
-        try:
-            return tcp.Connection.writeSomeData(self, data)
-        except SSL.WantWriteError:
-            return 0
-        except SSL.WantReadError:
-            self.writeBlockedOnRead = 1
-            return 0
-        except SSL.Error:
-            return main.CONNECTION_LOST
-
-    def _closeSocket(self):
-        """Called to close our socket."""
-        try:
-            self.socket.sock_shutdown(2)
-        except socket.error:
-            pass
 
-
-
-class Client(Connection, tcp.Client):
+class Client(tcp.Client):
     """I am an SSL client."""
     def __init__(self, host, port, bindAddress, ctxFactory, connector, reactor=None):
         # tcp.Client.__init__ depends on self.ctxFactory being set
         self.ctxFactory = ctxFactory
         tcp.Client.__init__(self, host, port, bindAddress, connector, reactor)
 
-    def createInternetSocket(self):
-        """(internal) create an SSL socket
-        """
-        sock = tcp.Client.createInternetSocket(self)
-        return SSL.Connection(self.ctxFactory.getContext(), sock)
+    def _finishInit(self, whenDone, skt, error, reactor):
+        tcp.Client._finishInit(self, whenDone, skt, error, reactor)
+        self.starttls(self.ctxFactory)
 
     def getHost(self):
         """Returns a tuple of ('SSL', hostname, port).
@@ -219,7 +125,7 @@
 
 
 
-class Server(Connection, tcp.Server):
+class Server(tcp.Server):
     """I am an SSL server.
     """
     
@@ -256,32 +162,9 @@
         sock = tcp.Port.createInternetSocket(self)
         return SSL.Connection(self.ctxFactory.getContext(), sock)
     
-    def doRead(self):
-        """Called when my socket is ready for reading.
-
-        This accepts a connection and calls self.protocol() to handle the
-        wire-level protocol.
-        """
-        try:
-            try:
-                skt, addr = self.socket.accept()
-            except socket.error, e:
-                if e.args[0] == tcp.EWOULDBLOCK:
-                    return
-                raise
-            except SSL.Error:
-                log.deferr()
-                return
-            protocol = self.factory.buildProtocol(addr)
-            if protocol is None:
-                skt.close()
-                return
-            s = self.sessionno
-            self.sessionno = s+1
-            transport = self.transport(skt, protocol, addr, self, s)
-            protocol.makeConnection(transport)
-        except:
-            log.deferr()
+    def _preMakeConnection(self, transport):
+        transport.TLS = 1
+        return tcp.Port._preMakeConnection(self, transport)
 
 
 class Connector(base.BaseConnector):
Index: twisted/internet/tcp.py
===================================================================
RCS file: /cvs/Twisted/twisted/internet/tcp.py,v
retrieving revision 1.105
diff -u -r1.105 tcp.py
--- twisted/internet/tcp.py	31 Jan 2003 23:50:20 -0000	1.105
+++ twisted/internet/tcp.py	6 Feb 2003 05:50:16 -0000
@@ -34,6 +34,10 @@
 import select
 import operator
 import warnings
+try:
+    from OpenSSL import SSL
+except ImportError:
+    SSL = None
 
 if os.name == 'nt':
     EINVAL      = 10022
@@ -78,6 +82,11 @@
 
     __implements__ = abstract.FileDescriptor.__implements__, interfaces.ITCPTransport
 
+    writeBlockedOnRead = 0
+    readBlockedOnWrite= 0
+    sslShutdown = 0
+    TLS = 0
+
     def __init__(self, skt, protocol, reactor=None):
         abstract.FileDescriptor.__init__(self, reactor=reactor)
         self.socket = skt
@@ -85,7 +94,42 @@
         self.fileno = skt.fileno
         self.protocol = protocol
 
+    def starttls(self, ctx):
+        self.socket = SSL.Connection(ctx.getContext(), self.socket)
+        self.fileno = self.socket.fileno
+        self.TLS = 1
+
     def doRead(self):
+        if self.TLS:
+            return self.doRead_TLS()
+        else:
+            return self.doRead_NOTLS()
+
+    def doRead_TLS(self):
+        if self.writeBlockedOnRead:
+            self.writeBlockedOnRead = 0
+            return self.doWrite()
+        try:
+            return self.doRead_NOTLS()
+        except SSL.ZeroReturnError:
+            # close SSL layer, since other side has done so, if we haven't
+            if not self.sslShutdown:
+                try:
+                    self.socket.shutdown()
+                    self.sslShutdown = 1
+                except SSL.Error:
+                    pass
+            return main.CONNECTION_DONE
+        except SSL.WantReadError:
+            return
+        except SSL.WantWriteError:
+            self.readBlockedOnWrite = 1
+            self.startWriting()
+            return
+        except SSL.Error:
+            return main.CONNECTION_LOST
+
+    def doRead_NOTLS(self):
         """Calls self.protocol.dataReceived with all available data.
 
         This reads up to self.bufferSize bytes of data from its socket, then
@@ -104,7 +148,22 @@
             return main.CONNECTION_LOST
         return self.protocol.dataReceived(data)
 
+    def doWrite(self):
+        if self.TLS:
+            if self.readBlockedOnWrite:
+                self.readBlockedOnWrite = 0
+                if not self.unsent: self.stopWriting()
+                return self.doRead()
+        return abstract.FileDescriptor.doWrite(self)
+
     def writeSomeData(self, data):
+        if self.TLS:
+            return self.writeSomeData_TLS(data)
+        else:
+            return self.writeSomeData_NOTLS(data)
+
+
+    def writeSomeData_NOTLS(self, data):
         """Connection.writeSomeData(data) -> #of bytes written | CONNECTION_LOST
         This writes as much data as possible to the socket and returns either
         the number of bytes read (which is positive) or a connection error code
@@ -118,6 +177,20 @@
             else:
                 return main.CONNECTION_LOST
 
+    def writeSomeData_TLS(self, data):
+        if not data:
+            return 0
+        try:
+            return self.writeSomeData_NOTLS(data)
+        except SSL.WantWriteError:
+            return 0
+        except SSL.WantReadError:
+            self.writeBlockedOnRead = 1
+            return 0
+        except SSL.Error:
+            return main.CONNECTION_LOST
+
+
     def _closeSocket(self):
         """Called to close our socket."""
         # This used to close() the socket, but that doesn't *really* close if
@@ -125,7 +198,10 @@
         # was inherited by a subprocess. And we really do want to close the
         # connection. So we use shutdown() instead.
         try:
-            self.socket.shutdown(2)
+            if self.TLS:
+                self.socket.sock_shutdown(2)
+            else:
+                self.socket.shutdown(2)
         except socket.error:
             pass
 
@@ -164,6 +240,32 @@
     def setTcpNoDelay(self, enabled):
         self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, enabled)
 
+    def _postLoseConnection(self):
+        """Gets called after loseConnection(), after buffered data is sent.
+
+        We close the SSL transport layer, and if the other side hasn't
+        closed it yet we start reading, waiting for a ZeroReturnError
+        which will indicate the SSL shutdown has completed.
+        """
+        if self.TLS:
+            try:
+                done = self.socket.shutdown()
+                self.sslShutdown = 1
+            except SSL.Error:
+                return main.CONNECTION_LOST
+            if done:
+                return main.CONNECTION_DONE
+            else:
+                # we wait for other side to close SSL connection -
+                # this will be signaled by SSL.ZeroReturnError when reading
+                # from the socket
+                self.stopWriting()
+                self.startReading()
+                return None # don't close socket just yet
+        else:
+            # XXX Should this be main.CONNECTION_DONE
+            #     or the returnwd value ever change?
+            return abstract.FileDescriptor._postLoseConnection(self)
 
 class BaseClient(Connection):
     """A base class for client TCP (and similiar) sockets.
@@ -181,6 +283,11 @@
         else:
             reactor.callLater(0, self.failIfNotConnected, error)
 
+    def starttls(self, ctx):
+        holder = Connection.starttls(self, ctx)
+        self.socket.set_connect_state()
+        return holder
+
     def stopConnecting(self):
         """Stop attempt to connect."""
         self.failIfNotConnected(error.UserError())
@@ -337,6 +444,11 @@
         """
         return self.repstr
 
+    def starttls(self, ctx):
+        holder = Connection.starttls(self, ctx)
+        self.socket.set_accept_state()
+        return holder
+        
     def getHost(self):
         """Returns a tuple of ('INET', hostname, port).
 
@@ -433,18 +545,27 @@
                         self.numberAccepts = i
                         break
                     raise
+                # XXX Hummmmmmmmmmmmmm what to do about this
+                except SSL.Error:
+                    if self.TLS:
+                        log.deferr()
+                        return
                 protocol = self.factory.buildProtocol(addr)
                 if protocol is None:
                     skt.close()
                     continue
                 s = self.sessionno
                 self.sessionno = s+1
-                transport = self.transport(skt, protocol, addr, self, s)
+                # XXX once again I just don't know if I should be doing this
+                transport = self._preMakeConnection(self.transport(skt, protocol, addr, self, s))
                 protocol.makeConnection(transport)
             else:
                 self.numberAccepts = self.numberAccepts+20
         except:
             log.deferr()
+
+    def _preMakeConnection(self, transport):
+        return transport
 
     def loseConnection(self):
         """Stop accepting connections on this port.
-------------- next part --------------

# Twisted, the Framework of Your Internet
# Copyright (C) 2001 Matthew W. Lefkowitz
# 
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
# 
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
# 
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

from twisted.internet.protocol import Protocol, Factory
from twisted.internet import udp, ssl

### Protocol Implementation

# This is just about the simplest possible protocol

# Server-side TLS context factory, handed to transport.starttls() when a
# client sends STARTTLS. Both the private key and the certificate are loaded
# from server.pem, and the context is restricted to TLSv1.
x = ssl.DefaultOpenSSLContextFactory(
    privateKeyFileName="server.pem",
    certificateFileName="server.pem",
    sslmethod=ssl.SSL.TLSv1_METHOD)


class Echo(Protocol):
    """Line-oriented echo protocol that can switch the connection to TLS.

    Each newline-terminated line is split at its first ``;``. The part before
    the ``;`` (or the whole line when there is none) is the command:

    * ``STARTTLS`` -- reply ``READY;...`` and call ``transport.starttls(x)``.
    * ``STOPTLS``  -- reply ``STOPED;`` and call ``transport.stoptls()``, but
      only when the transport provides that method.
    * anything else -- echo the line back, followed by its newline.
    """

    # Line terminator used both to split input and to terminate echoed lines.
    delimiter = "\n"
    # Bytes received after the last complete line. It is rebound per
    # instance on the first dataReceived call.
    _buffer = ""

    def dataReceived(self, data):
        """Buffer *data* and dispatch every complete line to lineReceived.

        TCP does not preserve write boundaries. The client's clear-text
        greeting and its ``STARTTLS;`` line can arrive in one call, and
        matching the command against the whole chunk would then miss it.
        Splitting into lines first fixes that. A trailing partial line is
        kept until the rest of it arrives.
        """
        print(data)
        lines = (self._buffer + data).split(self.delimiter)
        # The last element is whatever follows the final delimiter: either a
        # partial line, or "" when the data ended exactly on a newline.
        self._buffer = lines.pop()
        for line in lines:
            self.lineReceived(line)

    def lineReceived(self, line):
        """Handle one line whose delimiter has already been stripped."""
        # split(";", 1)[0] yields the whole line when there is no ";". This
        # matches the old fallback without a bare except that could hide
        # unrelated errors.
        command = line.split(";", 1)[0]
        if command == "STARTTLS":
            print("starting TLS")
            self.transport.write("READY;ajshdakjsd\n")
            # NOTE(review): write() may only buffer the reply. The patched
            # tcp.Connection decides between plain and TLS sends at flush
            # time (writeSomeData checks self.TLS), so READY could go out
            # encrypted once starttls() has run. Confirm the ordering.
            self.transport.starttls(x)
        elif command == "STOPTLS":
            # The TLS patch adds starttls() but no stoptls(). Only claim to
            # have stopped TLS when the transport can actually do it.
            stoptls = getattr(self.transport, "stoptls", None)
            if stoptls is None:
                self.transport.write("ERROR;STOPTLS not supported\n")
                return
            print("stoping TLS")
            self.transport.write("STOPED;\n")
            stoptls()
        else:
            self.transport.write(line + self.delimiter)


### Persistent Application Builder

# Factory for the echo server; the __main__ block below uses it to build a .tap file
class EchoClientFactory(Factory):
    """Listening-side factory that builds an Echo protocol per connection.

    NOTE(review): despite its name, this class is used as the server factory
    in the __main__ block. The two hooks below look like client-factory
    callbacks, and nothing visible here shows a plain Factory invoking them.
    Confirm whether they are ever called.
    """

    protocol = Echo

    def connectionFailed(self, connector, reason):
        # Report why the connection attempt failed.
        print('connection failed: %s' % (reason.getErrorMessage(),))

    def connectionLost(self, connector, reason):
        # Report why an established connection went away.
        print('connection lost: %s' % (reason.getErrorMessage(),))


if __name__ == '__main__':
    # The application is pickled to disk, so build it from classes reached
    # through this module's real name. Used directly from here they would be
    # recorded as __main__.Echo, which cannot be resolved when unpickling.
    import echoserv_tls
    from twisted.internet.app import Application

    serverFactory = echoserv_tls.EchoClientFactory()
    serverFactory.protocol = echoserv_tls.Echo

    app = Application("echo-tls")
    app.listenTCP(8000, serverFactory)
    app.save("start")
-------------- next part --------------
#!/usr/bin/python
# Twisted, the Framework of Your Internet
# Copyright (C) 2001 Matthew W. Lefkowitz
# 
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
# 
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
# 
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

from twisted.internet.protocol import ClientFactory, Protocol
from twisted.internet.app import Application
from twisted.internet import reactor, tcp, ssl
import sys

class myContext(ssl.ClientContextFactory):
    """Client-side context factory that always negotiates TLSv1."""

    isClient = 1

    def getContext(self):
        # A fresh OpenSSL context is built on every call.
        method = ssl.SSL.TLSv1_METHOD
        return ssl.SSL.Context(method)

# Shared instance handed to transport.starttls() once the server says READY.
x = myContext()

class EchoClient(Protocol):
    """Client half of the STARTTLS echo demo.

    On connect it sends two clear-text lines followed by ``STARTTLS;``.
    Incoming data is split into newline-terminated lines, and each line into
    ``command;other``. A line without a ``;`` has an empty command and is
    ignored. Commands:

    * ``READY``  -- switch the transport to TLS with the context factory
      ``x``, then send a line over the encrypted channel.
    * ``STOPED`` -- drop TLS (when the transport supports it), then send the
      ``end`` marker.
    * ``end``    -- the server echoed our end marker; close the connection.
    """

    end = "Bye-bye!"
    # Line terminator used to split incoming data.
    delimiter = "\n"
    # Bytes received after the last complete line. It is rebound per
    # instance on the first dataReceived call.
    _buffer = ""

    def connectionMade(self):
        self.transport.write("I am sending this in the clear\n")
        self.transport.write("And why should I not?\n")
        self.transport.write("STARTTLS;\n")

    def dataReceived(self, data):
        """Buffer *data* and dispatch every complete line to lineReceived.

        TCP may split or merge writes arbitrarily, so one line can span
        several calls. The trailing partial line is kept for the next call
        instead of being parsed as if it were complete.
        """
        lines = (self._buffer + data).split(self.delimiter)
        self._buffer = lines.pop()
        for line in lines:
            self.lineReceived(line)

    def lineReceived(self, line):
        """Act on one line whose delimiter has already been stripped."""
        parts = line.split(";", 1)
        if len(parts) == 2:
            command = parts[0]
        else:
            # No ";": treated as a plain echoed line, not a command.
            command = ""
        if command == self.end:
            self.transport.loseConnection()
        elif command == "READY":
            self.transport.starttls(x)
            self.transport.write("Spooks cannot see me now.\n")
            print(line)
        elif command == "STOPED":
            # The TLS patch defines starttls() but no stoptls(). Skip that
            # step when it is missing, but still end the session.
            stoptls = getattr(self.transport, "stoptls", None)
            if stoptls is not None:
                stoptls()
                self.transport.write("they can see again\n")
            self.transport.write("%s;ok\n" % (self.end,))

class EchoClientFactory(ClientFactory):
    """Creates the EchoClient and stops the reactor when the session ends."""

    protocol = EchoClient

    def clientConnectionFailed(self, connector, reason):
        # Connection could not be established: report it and shut down.
        print('connection failed: %s' % (reason.getErrorMessage(),))
        reactor.stop()

    def clientConnectionLost(self, connector, reason):
        # Connection closed (cleanly or not): report it and shut down.
        print('connection lost: %s' % (reason.getErrorMessage(),))
        reactor.stop()

if __name__ == '__main__':
    # Only connect and run the reactor when executed as a script. Importing
    # this module (e.g. to reuse EchoClient) then has no network side
    # effects and does not block in reactor.run().
    factory = EchoClientFactory()
    reactor.connectTCP('localhost', 8000, factory)
    reactor.run()


More information about the Twisted-Python mailing list