[Twisted-Python] TLS support for twisted (PATCH)
Skinny Puppy
skin_pup-twisted at damnable.happypoo.com
Wed Feb 5 23:04:44 MST 2003
Itamar Shtull-Trauring [twisted at itamarst.org] wrote:
> Skinny Puppy <skin_pup-twisted at damnable.happypoo.com> wrote:
>
> > Attached is the patch for TLS support within twisted. It is not the
> > best code, but it passes all the SSL tests (None working on that next)
> > and works. The diff also adds an echoclient_tls.py/echoserv_tls.py,
> > that work along with the echoclient_ssl.py/echoserv_ssl.py
>
> Neato! I don't have time to look at it right now, but it *will* be
> looked at sooner or later. If I don't check this in within a week bug me
> personally.
I will be more than happy to, but in the meantime I got bored,
checked out the CVS version of Twisted, and worked TLS into it. The patch is
attached, of course. Boredom will be my downfall, so please don't think I am
being pushy or anything. :)
Jeremy Rossi
-------------- next part --------------
? TLS.diff
? TLS.patch
? doc/examples/echoclient_tls.py
? doc/examples/echoserv_tls.py
Index: twisted/internet/ssl.py
===================================================================
RCS file: /cvs/Twisted/twisted/internet/ssl.py,v
retrieving revision 1.37
diff -u -r1.37 ssl.py
--- twisted/internet/ssl.py 3 Feb 2003 18:55:59 -0000 1.37
+++ twisted/internet/ssl.py 6 Feb 2003 05:50:16 -0000
@@ -95,112 +95,18 @@
return SSL.Context(SSL.SSLv3_METHOD)
-class Connection(tcp.Connection):
- """I am an SSL connection.
- """
-
- __implements__ = tcp.Connection.__implements__, interfaces.ISSLTransport
-
- writeBlockedOnRead = 0
- readBlockedOnWrite= 0
- sslShutdown = 0
-
- def getPeerCertificate(self):
- """Return the certificate for the peer."""
- return self.socket.get_peer_certificate()
-
- def _postLoseConnection(self):
- """Gets called after loseConnection(), after buffered data is sent.
-
- We close the SSL transport layer, and if the other side hasn't
- closed it yet we start reading, waiting for a ZeroReturnError
- which will indicate the SSL shutdown has completed.
- """
- try:
- done = self.socket.shutdown()
- self.sslShutdown = 1
- except SSL.Error:
- return main.CONNECTION_LOST
- if done:
- return main.CONNECTION_DONE
- else:
- # we wait for other side to close SSL connection -
- # this will be signaled by SSL.ZeroReturnError when reading
- # from the socket
- self.stopWriting()
- self.startReading()
- return None # don't close socket just yet
-
- def doRead(self):
- """See tcp.Connection.doRead for details.
- """
- if self.writeBlockedOnRead:
- self.writeBlockedOnRead = 0
- return self.doWrite()
- try:
- return tcp.Connection.doRead(self)
- except SSL.ZeroReturnError:
- # close SSL layer, since other side has done so, if we haven't
- if not self.sslShutdown:
- try:
- self.socket.shutdown()
- self.sslShutdown = 1
- except SSL.Error:
- pass
- return main.CONNECTION_DONE
- except SSL.WantReadError:
- return
- except SSL.WantWriteError:
- self.readBlockedOnWrite = 1
- self.startWriting()
- return
- except SSL.Error:
- return main.CONNECTION_LOST
-
- def doWrite(self):
- if self.readBlockedOnWrite:
- self.readBlockedOnWrite = 0
- if not self.unsent: self.stopWriting()
- return self.doRead()
- return tcp.Connection.doWrite(self)
-
- def writeSomeData(self, data):
- """See tcp.Connection.writeSomeData for details.
- """
- if not data:
- return 0
- try:
- return tcp.Connection.writeSomeData(self, data)
- except SSL.WantWriteError:
- return 0
- except SSL.WantReadError:
- self.writeBlockedOnRead = 1
- return 0
- except SSL.Error:
- return main.CONNECTION_LOST
-
- def _closeSocket(self):
- """Called to close our socket."""
- try:
- self.socket.sock_shutdown(2)
- except socket.error:
- pass
-
-
-class Client(Connection, tcp.Client):
+class Client(tcp.Client):
"""I am an SSL client."""
def __init__(self, host, port, bindAddress, ctxFactory, connector, reactor=None):
# tcp.Client.__init__ depends on self.ctxFactory being set
self.ctxFactory = ctxFactory
tcp.Client.__init__(self, host, port, bindAddress, connector, reactor)
- def createInternetSocket(self):
- """(internal) create an SSL socket
- """
- sock = tcp.Client.createInternetSocket(self)
- return SSL.Connection(self.ctxFactory.getContext(), sock)
+ def _finishInit(self, whenDone, skt, error, reactor):
+ tcp.Client._finishInit(self, whenDone, skt, error, reactor)
+ self.starttls(self.ctxFactory)
def getHost(self):
"""Returns a tuple of ('SSL', hostname, port).
@@ -219,7 +125,7 @@
-class Server(Connection, tcp.Server):
+class Server(tcp.Server):
"""I am an SSL server.
"""
@@ -256,32 +162,9 @@
sock = tcp.Port.createInternetSocket(self)
return SSL.Connection(self.ctxFactory.getContext(), sock)
- def doRead(self):
- """Called when my socket is ready for reading.
-
- This accepts a connection and calls self.protocol() to handle the
- wire-level protocol.
- """
- try:
- try:
- skt, addr = self.socket.accept()
- except socket.error, e:
- if e.args[0] == tcp.EWOULDBLOCK:
- return
- raise
- except SSL.Error:
- log.deferr()
- return
- protocol = self.factory.buildProtocol(addr)
- if protocol is None:
- skt.close()
- return
- s = self.sessionno
- self.sessionno = s+1
- transport = self.transport(skt, protocol, addr, self, s)
- protocol.makeConnection(transport)
- except:
- log.deferr()
+ def _preMakeConnection(self, transport):
+ transport.TLS = 1
+ return tcp.Port._preMakeConnection(self, transport)
class Connector(base.BaseConnector):
Index: twisted/internet/tcp.py
===================================================================
RCS file: /cvs/Twisted/twisted/internet/tcp.py,v
retrieving revision 1.105
diff -u -r1.105 tcp.py
--- twisted/internet/tcp.py 31 Jan 2003 23:50:20 -0000 1.105
+++ twisted/internet/tcp.py 6 Feb 2003 05:50:16 -0000
@@ -34,6 +34,10 @@
import select
import operator
import warnings
+try:
+ from OpenSSL import SSL
+except ImportError:
+ SSL = None
if os.name == 'nt':
EINVAL = 10022
@@ -78,6 +82,11 @@
__implements__ = abstract.FileDescriptor.__implements__, interfaces.ITCPTransport
+ writeBlockedOnRead = 0
+ readBlockedOnWrite= 0
+ sslShutdown = 0
+ TLS = 0
+
def __init__(self, skt, protocol, reactor=None):
abstract.FileDescriptor.__init__(self, reactor=reactor)
self.socket = skt
@@ -85,7 +94,42 @@
self.fileno = skt.fileno
self.protocol = protocol
+ def starttls(self, ctx):
+ self.socket = SSL.Connection(ctx.getContext(), self.socket)
+ self.fileno = self.socket.fileno
+ self.TLS = 1
+
def doRead(self):
+ if self.TLS:
+ return self.doRead_TLS()
+ else:
+ return self.doRead_NOTLS()
+
+ def doRead_TLS(self):
+ if self.writeBlockedOnRead:
+ self.writeBlockedOnRead = 0
+ return self.doWrite()
+ try:
+ return self.doRead_NOTLS()
+ except SSL.ZeroReturnError:
+ # close SSL layer, since other side has done so, if we haven't
+ if not self.sslShutdown:
+ try:
+ self.socket.shutdown()
+ self.sslShutdown = 1
+ except SSL.Error:
+ pass
+ return main.CONNECTION_DONE
+ except SSL.WantReadError:
+ return
+ except SSL.WantWriteError:
+ self.readBlockedOnWrite = 1
+ self.startWriting()
+ return
+ except SSL.Error:
+ return main.CONNECTION_LOST
+
+ def doRead_NOTLS(self):
"""Calls self.protocol.dataReceived with all available data.
This reads up to self.bufferSize bytes of data from its socket, then
@@ -104,7 +148,22 @@
return main.CONNECTION_LOST
return self.protocol.dataReceived(data)
+ def doWrite(self):
+ if self.TLS:
+ if self.readBlockedOnWrite:
+ self.readBlockedOnWrite = 0
+ if not self.unsent: self.stopWriting()
+ return self.doRead()
+ return abstract.FileDescriptor.doWrite(self)
+
def writeSomeData(self, data):
+ if self.TLS:
+ return self.writeSomeData_TLS(data)
+ else:
+ return self.writeSomeData_NOTLS(data)
+
+
+ def writeSomeData_NOTLS(self, data):
"""Connection.writeSomeData(data) -> #of bytes written | CONNECTION_LOST
This writes as much data as possible to the socket and returns either
the number of bytes read (which is positive) or a connection error code
@@ -118,6 +177,20 @@
else:
return main.CONNECTION_LOST
+ def writeSomeData_TLS(self, data):
+ if not data:
+ return 0
+ try:
+ return self.writeSomeData_NOTLS(data)
+ except SSL.WantWriteError:
+ return 0
+ except SSL.WantReadError:
+ self.writeBlockedOnRead = 1
+ return 0
+ except SSL.Error:
+ return main.CONNECTION_LOST
+
+
def _closeSocket(self):
"""Called to close our socket."""
# This used to close() the socket, but that doesn't *really* close if
@@ -125,7 +198,10 @@
# was inherited by a subprocess. And we really do want to close the
# connection. So we use shutdown() instead.
try:
- self.socket.shutdown(2)
+ if self.TLS:
+ self.socket.sock_shutdown(2)
+ else:
+ self.socket.shutdown(2)
except socket.error:
pass
@@ -164,6 +240,32 @@
def setTcpNoDelay(self, enabled):
self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, enabled)
+ def _postLoseConnection(self):
+ """Gets called after loseConnection(), after buffered data is sent.
+
+ We close the SSL transport layer, and if the other side hasn't
+ closed it yet we start reading, waiting for a ZeroReturnError
+ which will indicate the SSL shutdown has completed.
+ """
+ if self.TLS:
+ try:
+ done = self.socket.shutdown()
+ self.sslShutdown = 1
+ except SSL.Error:
+ return main.CONNECTION_LOST
+ if done:
+ return main.CONNECTION_DONE
+ else:
+ # we wait for other side to close SSL connection -
+ # this will be signaled by SSL.ZeroReturnError when reading
+ # from the socket
+ self.stopWriting()
+ self.startReading()
+ return None # don't close socket just yet
+ else:
+ # XXX Should this be main.CONNECTION_DONE
+ # or the returnwd value ever change?
+ return abstract.FileDescriptor._postLoseConnection(self)
class BaseClient(Connection):
"""A base class for client TCP (and similiar) sockets.
@@ -181,6 +283,11 @@
else:
reactor.callLater(0, self.failIfNotConnected, error)
+ def starttls(self, ctx):
+ holder = Connection.starttls(self, ctx)
+ self.socket.set_connect_state()
+ return holder
+
def stopConnecting(self):
"""Stop attempt to connect."""
self.failIfNotConnected(error.UserError())
@@ -337,6 +444,11 @@
"""
return self.repstr
+ def starttls(self, ctx):
+ holder = Connection.starttls(self, ctx)
+ self.socket.set_accept_state()
+ return holder
+
def getHost(self):
"""Returns a tuple of ('INET', hostname, port).
@@ -433,18 +545,27 @@
self.numberAccepts = i
break
raise
+ # XXX Hummmmmmmmmmmmmm what to do about this
+ except SSL.Error:
+ if self.TLS:
+ log.deferr()
+ return
protocol = self.factory.buildProtocol(addr)
if protocol is None:
skt.close()
continue
s = self.sessionno
self.sessionno = s+1
- transport = self.transport(skt, protocol, addr, self, s)
+ # XXX once again I just don't know if I should be doing this
+ transport = self._preMakeConnection(self.transport(skt, protocol, addr, self, s))
protocol.makeConnection(transport)
else:
self.numberAccepts = self.numberAccepts+20
except:
log.deferr()
+
+ def _preMakeConnection(self, transport):
+ return transport
def loseConnection(self):
"""Stop accepting connections on this port.
-------------- next part --------------
# Twisted, the Framework of Your Internet
# Copyright (C) 2001 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
from twisted.internet.protocol import Protocol, Factory
from twisted.internet import udp, ssl
### Protocol Implementation
# This is just about the simplest possible protocol
# Server-side TLS context factory, shared by every Echo connection that asks
# for STARTTLS.  A single PEM file supplies both the private key and the
# certificate; the protocol method is pinned to TLSv1.
x = ssl.DefaultOpenSSLContextFactory(
    sslmethod=ssl.SSL.TLSv1_METHOD,
    certificateFileName="server.pem",
    privateKeyFileName="server.pem",
)
class Echo(Protocol):
    """Echo protocol that can switch an established connection to TLS.

    Incoming bytes are split into newline-terminated lines and each complete
    line is handled by lineReceived().  A line of the form ``COMMAND;rest``
    is treated as a command:

    * ``STARTTLS`` -- reply ``READY;...`` and start TLS on the transport
      using the module-level context factory ``x``.
    * ``STOPTLS``  -- reply ``STOPED;`` and ask the transport to stop TLS
      (or reply ``ERROR;...`` if the transport has no stoptls()).

    Every other line is written back to the peer followed by a newline.
    """

    # Text received after the last newline, held until its line completes.
    # TCP does not preserve write boundaries: one dataReceived() call may
    # carry several lines (the client writes three in a row on connect) or
    # only part of one, so commands must be parsed per line, not per chunk.
    _partial = ""

    def dataReceived(self, data):
        "Buffer incoming data and dispatch every complete line."
        print(data)
        lines = (self._partial + data).split("\n")
        # The final element follows the last newline: either an incomplete
        # line or "" when data ended exactly on a newline.
        self._partial = lines.pop()
        for line in lines:
            self.lineReceived(line)

    def lineReceived(self, line):
        "Handle one line received from the peer (without its newline)."
        # With no ';' the whole line is the command, as before.
        command = line.split(";", 1)[0]
        if command == "STARTTLS":
            print("starting TLS")
            # NOTE(review): READY is queued with write() before starttls()
            # replaces the transport's socket with an SSL.Connection; if the
            # transport only flushes its buffer later, READY may be sent
            # encrypted.  Confirm pending plaintext is flushed first.
            self.transport.write("READY;ajshdakjsd\n")
            self.transport.starttls(x)
        elif command == "STOPTLS":
            print("stoping TLS")
            stoptls = getattr(self.transport, "stoptls", None)
            if stoptls is None:
                # The transport changes in this patch add starttls() but no
                # stoptls(); calling it unguarded raised AttributeError.
                self.transport.write("ERROR;STOPTLS not supported\n")
            else:
                self.transport.write("STOPED;\n")
                stoptls()
        else:
            self.transport.write(line + "\n")
### Persistent Application Builder
# This builds a .tap file
class EchoClientFactory(Factory):
    """Factory that builds an Echo protocol for each accepted connection.

    NOTE(review): despite the name, this is passed to listenTCP() below, so
    it acts as a server factory.  connectionFailed/connectionLost are not the
    ClientFactory hook names (those are clientConnectionFailed and
    clientConnectionLost) -- confirm whether Factory ever invokes them.
    """
    protocol = Echo

    def connectionFailed(self, connector, reason):
        """Print why a connection attempt failed."""
        message = reason.getErrorMessage()
        print('connection failed: %s' % (message,))

    def connectionLost(self, connector, reason):
        """Print why an established connection was lost."""
        message = reason.getErrorMessage()
        print('connection lost: %s' % (message,))
if __name__ == '__main__':
    # Re-import this file under its real module name so the saved
    # application references echoserv_tls.Echo; a reference to
    # __main__.Echo could not be resolved when the file is unpickled.
    import echoserv_tls
    from twisted.internet.app import Application

    factory = echoserv_tls.EchoClientFactory()
    factory.protocol = echoserv_tls.Echo

    application = Application("echo-tls")
    application.listenTCP(8000, factory)
    # Persist the configured application (a .tap file) rather than run it.
    application.save("start")
-------------- next part --------------
#!/usr/bin/python
# Twisted, the Framework of Your Internet
# Copyright (C) 2001 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
from twisted.internet.protocol import ClientFactory, Protocol
from twisted.internet.app import Application
from twisted.internet import reactor, tcp, ssl
import sys
class myContext(ssl.ClientContextFactory):
    """Client-side context factory that always negotiates TLSv1."""

    isClient = 1

    def getContext(self):
        """Build and return a new OpenSSL context using TLSv1."""
        method = ssl.SSL.TLSv1_METHOD
        return ssl.SSL.Context(method)


# Single factory instance handed to transport.starttls() when the server
# answers READY.
x = myContext()
class EchoClient(Protocol):
    """Client side of the STARTTLS echo demo.

    On connect it sends two plaintext lines followed by ``STARTTLS;``.
    Replies are handled one complete line at a time:

    * ``READY;...``    -- switch the transport to TLS and send a line over it.
    * ``STOPED;...``   -- stop TLS, send a plaintext line, then the
      ``Bye-bye!;ok`` farewell.
    * ``Bye-bye!;...`` -- the server echoed the farewell; disconnect.

    Lines without a ``;`` separator are plain echoed text and are ignored.
    """
    end = "Bye-bye!"

    # Text received after the last newline, held until its line completes.
    # TCP may split one reply line across several dataReceived() calls.
    _partial = ""

    def connectionMade(self):
        self.transport.write("I am sending this in the clear\n")
        self.transport.write("And why should I not?\n")
        self.transport.write("STARTTLS;\n")

    def dataReceived(self, data):
        """Buffer incoming data and handle every complete line."""
        lines = (self._partial + data).split("\n")
        # Keep the unterminated tail ("" if data ended on a newline).
        self._partial = lines.pop()
        for line in lines:
            self.lineReceived(line)

    def lineReceived(self, line):
        """Act on one reply line (without its trailing newline)."""
        if ";" in line:
            command = line.split(";", 1)[0]
        else:
            command = ""
        if command == self.end:
            self.transport.loseConnection()
        elif command == "READY":
            self.transport.starttls(x)
            self.transport.write("Spooks cannot see me now.\n")
            print(line)
        elif command == "STOPED":
            # NOTE(review): the transport changes in this patch add
            # starttls() but no stoptls(); confirm it exists before relying
            # on this branch.
            self.transport.stoptls()
            self.transport.write("they can see again\n")
            self.transport.write("%s;ok\n" % (self.end,))
class EchoClientFactory(ClientFactory):
    """Build EchoClient protocols; stop the reactor when the connection ends."""
    protocol = EchoClient

    def _report(self, what, reason):
        # Shared by both callbacks: print the reason, then shut down.
        print('%s: %s' % (what, reason.getErrorMessage()))
        reactor.stop()

    def clientConnectionFailed(self, connector, reason):
        """The connection attempt failed: report it and stop the reactor."""
        self._report('connection failed', reason)

    def clientConnectionLost(self, connector, reason):
        """The connection was lost: report it and stop the reactor."""
        self._report('connection lost', reason)
if __name__ == '__main__':
    # Connect and run only when executed as a script, so importing this
    # module (e.g. to reuse EchoClient) does not start the reactor.
    factory = EchoClientFactory()
    reactor.connectTCP('localhost', 8000, factory)
    reactor.run()
More information about the Twisted-Python
mailing list