[Twisted-Python] TLS support for twisted (PATCH)
Skinny Puppy
skin_pup-twisted at damnable.happypoo.com
Wed Feb 5 20:16:32 EST 2003
Attached is the patch for TLS support within Twisted. It is not the best
code, but it passes all the existing SSL tests (there are no TLS tests yet;
I am working on those next) and it works. The diff also adds
echoclient_tls.py/echoserv_tls.py examples that parallel the existing
echoclient_ssl.py/echoserv_ssl.py ones.
There are a few things I am not sure how best to handle:
- What error should starttls raise if OpenSSL is not installed? exarkun
suggested the following, but I have no idea how to do it cleanly. Any
ideas?
exarkun : currently, connect/listenSSL aren't defined at all if
the ssl support libs can't be implemented
exarkun : might it make sense to just duplicate that?
- The doRead_TLS/doRead_NOTLS and writeSomeData_TLS/writeSomeData_NOTLS
pairs are just damn ugly and could be combined, but I felt it was best to
leave them apart for now. Is that OK, or should I merge them?
Jeremy Rossi
-------------- next part --------------
diff -urP Twisted-1.0.2alpha4/doc/examples/echoclient_tls.py Twisted-1.0.2alpha4-halfass-branch/doc/examples/echoclient_tls.py
--- Twisted-1.0.2alpha4/doc/examples/echoclient_tls.py Wed Dec 31 19:00:00 1969
+++ Twisted-1.0.2alpha4-halfass-branch/doc/examples/echoclient_tls.py Wed Feb 5 19:21:06 2003
@@ -0,0 +1,81 @@
+#!/usr/bin/python
+# Twisted, the Framework of Your Internet
+# Copyright (C) 2001 Matthew W. Lefkowitz
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of version 2.1 of the GNU Lesser General Public
+# License as published by the Free Software Foundation.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+
+"""STARTTLS echo client example; run it against echoserv_tls.py.
+
+The client sends two lines in the clear followed by "STARTTLS;". When the
+server answers with a "READY;..." line the transport is upgraded with
+starttls() and one more line is sent, this time over TLSv1.
+"""
+
+from twisted.internet.protocol import ClientFactory, Protocol
+from twisted.internet import reactor, ssl
+
+
+class myContext(ssl.ClientContextFactory):
+    """Client context factory that builds TLSv1 contexts."""
+    isClient = 1
+
+    def getContext(self):
+        return ssl.SSL.Context(ssl.SSL.TLSv1_METHOD)
+
+x = myContext()
+
+
+class EchoClient(Protocol):
+    """Talks in the clear, then switches to TLS once the server is READY."""
+    end = "Bye-bye!"
+
+    def connectionMade(self):
+        self.transport.write("I am sending this in the clear\n")
+        self.transport.write("And why should I not?\n")
+        self.transport.write("STARTTLS;\n")
+
+    def dataReceived(self, data):
+        for line in data.split("\n"):
+            # Lines look like "COMMAND;rest"; a line without ';' has no command.
+            try:
+                command, other = line.split(";", 1)
+            except ValueError:
+                command = ""
+            if command == self.end:
+                self.transport.loseConnection()
+            elif command == "READY":
+                # The server agreed: everything written from here on uses TLS.
+                self.transport.starttls(x)
+                self.transport.write("Spooks cannot see me now.\n")
+            elif line:
+                # Skip the empty piece split() leaves after a trailing "\n".
+                print line
+
+
+class EchoClientFactory(ClientFactory):
+    protocol = EchoClient
+
+    def clientConnectionFailed(self, connector, reason):
+        print 'connection failed:', reason.getErrorMessage()
+        reactor.stop()
+
+    def clientConnectionLost(self, connector, reason):
+        print 'connection lost:', reason.getErrorMessage()
+        reactor.stop()
+
+
+if __name__ == '__main__':
+    factory = EchoClientFactory()
+    reactor.connectTCP('localhost', 8000, factory)
+    reactor.run()
diff -urP Twisted-1.0.2alpha4/doc/examples/echoserv_tls.py Twisted-1.0.2alpha4-halfass-branch/doc/examples/echoserv_tls.py
--- Twisted-1.0.2alpha4/doc/examples/echoserv_tls.py Wed Dec 31 19:00:00 1969
+++ Twisted-1.0.2alpha4-halfass-branch/doc/examples/echoserv_tls.py Wed Feb 5 19:20:55 2003
@@ -0,0 +1,81 @@
+# Twisted, the Framework of Your Internet
+# Copyright (C) 2001 Matthew W. Lefkowitz
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of version 2.1 of the GNU Lesser General Public
+# License as published by the Free Software Foundation.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+
+"""STARTTLS echo server example; pair it with echoclient_tls.py.
+
+Each chunk received is echoed back, unless it starts with "STARTTLS;":
+then the server answers "READY;..." and upgrades the connection with
+transport.starttls(). Running this module as a script saves an "echo-tls"
+application listening on TCP port 8000 (see the __main__ block).
+"""
+
+from twisted.internet.protocol import Protocol, Factory
+from twisted.internet import ssl
+
+### Protocol Implementation
+
+# Server-side TLSv1 context; server.pem holds both the private key and
+# the certificate.
+x = ssl.DefaultOpenSSLContextFactory(privateKeyFileName="server.pem",
+                                     certificateFileName="server.pem",
+                                     sslmethod=ssl.SSL.TLSv1_METHOD)
+
+
+class Echo(Protocol):
+    def dataReceived(self, data):
+        "As soon as any data is received, write it back."
+        print data
+        # NOTE(review): each read is treated as one message, so STARTTLS is
+        # only recognised when a read starts with "STARTTLS;" (no line framing).
+        try:
+            command, other = data.split(";", 1)
+        except ValueError:
+            command = data
+        if command == "STARTTLS":
+            print "starting TLS"
+            self.transport.write("READY;ajshdakjsd\n")
+            self.transport.starttls(x)
+        else:
+            self.transport.write(data)
+
+
+### Persistent Application Builder
+
+# This builds a .tap file
+class EchoClientFactory(Factory):
+    """Factory for Echo connections (used server-side despite its name)."""
+    protocol = Echo
+
+    # NOTE(review): these look like ClientFactory-style hooks; confirm that
+    # Factory ever calls them, otherwise they are dead code.
+    def connectionFailed(self, connector, reason):
+        print 'connection failed:', reason.getErrorMessage()
+
+    def connectionLost(self, connector, reason):
+        print 'connection lost:', reason.getErrorMessage()
+
+
+if __name__ == '__main__':
+    # Since this is persistent, it's important to get the module naming right
+    # (If we just used Echo, then it would be __main__.Echo when it attempted
+    # to unpickle)
+    import echoserv_tls
+    from twisted.internet.app import Application
+    factory = echoserv_tls.EchoClientFactory()
+    factory.protocol = echoserv_tls.Echo
+    app = Application("echo-tls")
+    app.listenTCP(8000, factory)
+    app.save("start")
diff -urP Twisted-1.0.2alpha4/twisted/internet/ssl.py Twisted-1.0.2alpha4-halfass-branch/twisted/internet/ssl.py
--- Twisted-1.0.2alpha4/twisted/internet/ssl.py Wed Jan 8 09:18:53 2003
+++ Twisted-1.0.2alpha4-halfass-branch/twisted/internet/ssl.py Wed Feb 5 19:20:06 2003
@@ -40,11 +40,11 @@
import socket
# sibling imports
-import tcp, main, interfaces
+import main, interfaces, tcp
# Twisted imports
from twisted.python import log
-
+#from twisted.internet import reactor  # NOTE(review): leftover commented-out import; consider dropping it
class ContextFactory:
"""A factory for SSL context objects, for server SSL connections."""
@@ -95,101 +95,7 @@
return SSL.Context(SSL.SSLv3_METHOD)
-class Connection(tcp.Connection):
- """I am an SSL connection.
- """
-
- __implements__ = tcp.Connection.__implements__, interfaces.ISSLTransport
-
- writeBlockedOnRead = 0
- readBlockedOnWrite= 0
- sslShutdown = 0
-
- def getPeerCertificate(self):
- """Return the certificate for the peer."""
- return self.socket.get_peer_certificate()
-
- def _postLoseConnection(self):
- """Gets called after loseConnection(), after buffered data is sent.
-
- We close the SSL transport layer, and if the other side hasn't
- closed it yet we start reading, waiting for a ZeroReturnError
- which will indicate the SSL shutdown has completed.
- """
- try:
- done = self.socket.shutdown()
- self.sslShutdown = 1
- except SSL.Error:
- return main.CONNECTION_LOST
- if done:
- return main.CONNECTION_DONE
- else:
- # we wait for other side to close SSL connection -
- # this will be signaled by SSL.ZeroReturnError when reading
- # from the socket
- self.stopWriting()
- self.startReading()
- return None # don't close socket just yet
-
- def doRead(self):
- """See tcp.Connection.doRead for details.
- """
- if self.writeBlockedOnRead:
- self.writeBlockedOnRead = 0
- return self.doWrite()
- try:
- return tcp.Connection.doRead(self)
- except SSL.ZeroReturnError:
- # close SSL layer, since other side has done so, if we haven't
- if not self.sslShutdown:
- try:
- self.socket.shutdown()
- self.sslShutdown = 1
- except SSL.Error:
- pass
- return main.CONNECTION_DONE
- except SSL.WantReadError:
- return
- except SSL.WantWriteError:
- self.readBlockedOnWrite = 1
- self.startWriting()
- return
- except SSL.Error:
- return main.CONNECTION_LOST
-
- def doWrite(self):
- if self.readBlockedOnWrite:
- self.readBlockedOnWrite = 0
- if not self.unsent: self.stopWriting()
- return self.doRead()
- return tcp.Connection.doWrite(self)
-
- def writeSomeData(self, data):
- """See tcp.Connection.writeSomeData for details.
- """
- if not data:
- return 0
-
- try:
- return tcp.Connection.writeSomeData(self, data)
- except SSL.WantWriteError:
- return 0
- except SSL.WantReadError:
- self.writeBlockedOnRead = 1
- return 0
- except SSL.Error:
- return main.CONNECTION_LOST
-
- def _closeSocket(self):
- """Called to close our socket."""
- try:
- self.socket.sock_shutdown(2)
- except socket.error:
- pass
-
-
-
-class Client(Connection, tcp.TCPClient):
+class Client(tcp.TCPClient):  # NOTE(review): the removed ssl.Connection declared interfaces.ISSLTransport; nothing shown here re-declares it
"""I am an SSL client.
"""
@@ -197,11 +103,15 @@
self.ctxFactory = ctxFactory
tcp.TCPClient.__init__(self, host, port, bindAddress, connector, reactor)
+ def _finishInit(self, whenDone, skt, error, reactor):  # plain TCP client setup first, then layer TLS on via starttls()
+ tcp.TCPClient._finishInit(self, whenDone, skt, error, reactor)
+ self.starttls(self.ctxFactory)  # NOTE(review): also runs when error is set; confirm starttls is safe after a failed connect
+
def createInternetSocket(self):
"""(internal) create an SSL socket
"""
sock = tcp.TCPClient.createInternetSocket(self)
- return SSL.Connection(self.ctxFactory.getContext(), sock)
+ return sock  # no longer SSL-wrapped here; _finishInit wraps it via starttls()
def getHost(self):
"""Returns a tuple of ('SSL', hostname, port).
@@ -220,13 +130,14 @@
-class Server(Connection, tcp.Server):
+class Server(tcp.Connection, tcp.Server):  # NOTE(review): tcp.Connection may already be a base of tcp.Server; ISSLTransport is not re-declared; the __init__ comment below still names the removed ssl.Connection
"""I am an SSL server.
"""
def __init__(*args, **kwargs):
# we need those so we don't use ssl.Connection's __init__
apply(tcp.Server.__init__, args, kwargs)
+
def getHost(self):
"""Returns a tuple of ('SSL', hostname, port).
@@ -258,7 +169,8 @@
"""(internal) create an SSL socket
"""
sock = tcp.Port.createInternetSocket(self)
- return SSL.Connection(self.ctxFactory.getContext(), sock)
+ sock = SSL.Connection(self.ctxFactory.getContext(), sock)  # same wrapping as the removed line, split in two
+ return sock
def doRead(self):
"""Called when my socket is ready for reading.
@@ -283,6 +195,7 @@
s = self.sessionno
self.sessionno = s+1
transport = self.transport(skt, protocol, addr, self, s)
+ transport.TLS = 1  # route this transport's I/O through tcp.Connection's TLS code paths
protocol.makeConnection(transport)
except:
log.deferr()
diff -urP Twisted-1.0.2alpha4/twisted/internet/tcp.py Twisted-1.0.2alpha4-halfass-branch/twisted/internet/tcp.py
--- Twisted-1.0.2alpha4/twisted/internet/tcp.py Wed Jan 1 09:32:27 2003
+++ Twisted-1.0.2alpha4-halfass-branch/twisted/internet/tcp.py Wed Feb 5 19:20:02 2003
@@ -34,6 +34,10 @@
import select
import operator
import warnings
+try:
+ from OpenSSL import SSL
+except ImportError:  # no OpenSSL: plain TCP still works, but starttls() cannot
+ SSL = None
if os.name == 'nt':
EINVAL = 10022
@@ -73,6 +77,10 @@
This is an abstract superclass of all objects which represent a TCP/IP
connection based socket.
"""
+ writeBlockedOnRead = 0  # TLS: a write hit WantReadError; doRead retries it
+ readBlockedOnWrite = 0  # TLS: a read hit WantWriteError; doWrite retries it
+ sslShutdown = 0  # set once the SSL-level shutdown() has been called
+ TLS = 0  # true while self.socket is an SSL.Connection (see starttls)
__implements__ = abstract.FileDescriptor.__implements__, interfaces.ITCPTransport
@@ -83,7 +91,47 @@
self.fileno = skt.fileno
self.protocol = protocol
+ def starttls(self, ctx):  # replace self.socket with an SSL.Connection built from ctx.getContext()
+ self._oldsocket = self.socket  # kept so stoptls() can put the plain socket back
+ self.socket = SSL.Connection(ctx.getContext(), self.socket)  # NOTE(review): SSL is None without OpenSSL, so this raises AttributeError then
+ self.fileno = self.socket.fileno
+ self.TLS = 1
+
+ def stoptls(self):  # NOTE(review): does not rebind self.fileno or reset sslShutdown; confirm that is intended
+ self.socket = self._oldsocket
+ self.TLS = 0
+
def doRead(self):
+ if self.TLS:  # dispatch on the TLS flag
+ return self.doRead_TLS()
+ else:
+ return self.doRead_NOTLS()
+
+ def doRead_TLS(self):  # doRead_NOTLS plus handling of pyOpenSSL errors
+ if self.writeBlockedOnRead:  # a write was waiting for readability: retry it now
+ self.writeBlockedOnRead = 0
+ return self.doWrite()
+ try:
+ return self.doRead_NOTLS()
+ except SSL.ZeroReturnError:
+ # close SSL layer, since other side has done so, if we haven't
+ if not self.sslShutdown:
+ try:
+ self.socket.shutdown()
+ self.sslShutdown = 1
+ except SSL.Error:
+ pass
+ return main.CONNECTION_DONE
+ except SSL.WantReadError:  # nothing to deliver yet; wait for the next read event
+ return
+ except SSL.WantWriteError:  # SSL must write before it can read; resumed from doWrite
+ self.readBlockedOnWrite = 1
+ self.startWriting()
+ return
+ except SSL.Error:
+ return main.CONNECTION_LOST
+
+ def doRead_NOTLS(self):
"""Calls self.protocol.dataReceived with all available data.
This reads up to self.bufferSize bytes of data from its socket, then
@@ -102,7 +150,36 @@
return main.CONNECTION_LOST
return self.protocol.dataReceived(data)
+
+ def doWrite(self):  # with TLS, first resume a read that was blocked on writability
+ if self.TLS:
+ if self.readBlockedOnWrite:
+ self.readBlockedOnWrite = 0
+ if not self.unsent: self.stopWriting()
+ return self.doRead()
+ return abstract.FileDescriptor.doWrite(self)
+
def writeSomeData(self, data):
+ if self.TLS:
+ return self.writeSomeData_TLS(data)
+ else:
+ return self.writeSomeData_NOTLS(data)
+
+
+ def writeSomeData_TLS(self, data):  # pyOpenSSL want-read/want-write are reported as "0 bytes written"
+ if not data:
+ return 0
+ try:
+ return self.writeSomeData_NOTLS(data)
+ except SSL.WantWriteError:
+ return 0
+ except SSL.WantReadError:  # resumed by doRead_TLS once the socket is readable
+ self.writeBlockedOnRead = 1
+ return 0
+ except SSL.Error:
+ return main.CONNECTION_LOST
+
+ def writeSomeData_NOTLS(self, data):
"""Connection.writeSomeData(data) -> #of bytes written | CONNECTION_LOST
This writes as much data as possible to the socket and returns either
the number of bytes read (which is positive) or a connection error code
@@ -123,7 +200,10 @@
# was inherited by a subprocess. And we really do want to close the
# connection. So we use shutdown() instead.
try:
- self.socket.shutdown(2)
+ if self.TLS:  # SSL.Connection.shutdown() is the SSL-level close; sock_shutdown() is the TCP one
+ self.socket.sock_shutdown(2)
+ else:
+ self.socket.shutdown(2)
except socket.error:
pass
@@ -162,11 +242,45 @@
def setTcpNoDelay(self, enabled):
self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, enabled)
+ def getPeerCertificate(self):
+ """Return the peer's certificate (self.socket must be an SSL.Connection)."""
+ return self.socket.get_peer_certificate()
+
+ def _postLoseConnection(self):
+ """Gets called after loseConnection(), after buffered data is sent.
+
+ With TLS we close the SSL layer, and if the other side hasn't closed
+ it yet we keep reading until a ZeroReturnError shows that the SSL
+ shutdown has completed. Without TLS the connection is simply done.
+ """
+ if self.TLS:
+ try:
+ done = self.socket.shutdown()
+ self.sslShutdown = 1
+ except SSL.Error:
+ return main.CONNECTION_LOST
+ if done:
+ return main.CONNECTION_DONE
+ else:
+ # we wait for other side to close SSL connection -
+ # this will be signaled by SSL.ZeroReturnError when reading
+ # from the socket
+ self.stopWriting()
+ self.startReading()
+ return None # don't close socket just yet
+ else:
+ # plain TCP: no SSL layer to shut down first
+ return main.CONNECTION_DONE
class BaseClient(Connection):
"""A base class for client TCP (and similiar) sockets.
"""
+ def starttls(self, ctx):  # client side: the SSL handshake runs in connect state
+ holder = Connection.starttls(self, ctx)
+ self.socket.set_connect_state()
+ return holder
+
def _finishInit(self, whenDone, skt, error, reactor):
"""Called by base classes to continue to next stage of initialization."""
if whenDone:
@@ -360,6 +474,12 @@
self.startReading()
self.connected = 1
+
+ def starttls(self, ctx):  # server side: the SSL handshake runs in accept state
+ holder = Connection.starttls(self, ctx)
+ self.socket.set_accept_state()
+ return holder
+
def __repr__(self):
"""A string representation of this connection.
"""
@@ -489,6 +609,11 @@
self.numberAccepts = i
break
raise
+ # An SSL error while accepting only affects this one connection: log it.
+ except getattr(SSL, "Error", ()):
+ # (SSL is None without OpenSSL; an empty tuple matches no exception)
+ log.deferr()
+ return
protocol = self.factory.buildProtocol(addr)
if protocol is None:
skt.close()
More information about the Twisted-Python
mailing list