[Twisted-Python] [PATCH] HTTP server fix for the FD problem

Andrew Bennetts andrew-twisted at puzzling.org
Sun May 11 05:28:42 EDT 2003


This adds timeouts to the HTTPChannel class, so that dead old sockets can't
hang around indefinitely and eat all the file descriptors.  The timeout
machinery is added to twisted/protocols/policies.py, and could be reused for
other protocols.

I'm a little concerned that I'm being a touch too evil, though:

-class HTTPChannel(basic.LineReceiver):
+class HTTPChannel(policies.TimeoutWrapper(basic.LineReceiver)):

I'm posting to the list rather than just checking it in so that people can
veto this sort of madness if they like :)  (Oh, and some testing that this
actually fixes the problem would be nice, too).

I'm also including Moshe's simpler, less general patch that only fixes
HTTPChannel, rather than being reusable, in case people prefer that
solution.

-Andrew.


**** My patch: ****

Index: twisted/protocols/policies.py
===================================================================
RCS file: /cvs/Twisted/twisted/protocols/policies.py,v
retrieving revision 1.8
diff -u -r1.8 policies.py
--- twisted/protocols/policies.py	10 Apr 2003 07:03:51 -0000	1.8
+++ twisted/protocols/policies.py	11 May 2003 09:07:38 -0000
@@ -90,6 +90,12 @@
         self.wrappedFactory = wrappedFactory
         self.protocols = {}
 
+    def startFactory(self):
+        self.wrappedFactory.startFactory()
+
+    def stopFactory(self):
+        self.wrappedFactory.stopFactory()
+
     def startedConnecting(self, connector):
         self.wrappedFactory.startedConnecting(connector)
 
@@ -282,6 +288,7 @@
     maxConnectionsPerPeer = 5
 
     def startFactory(self):
+        WrappingFactory.startFactory(self)
         self.peerConnections = {}
         
     def buildProtocol(self, addr):
@@ -298,6 +305,113 @@
         if self.peerConnections[peerHost] == 0:
             del self.peerConnections[peerHost]
 
+
+class _TransportWrapper:
+    """Proxy for a transport that resets the protocol's idle timeout on writes.
+
+    Every other attribute is looked up on the wrapped transport.
+    """
+
+    def __init__(self, original, protocol):
+        self.original = original
+        self.protocol = protocol
+
+    def write(self, data):
+        self.original.write(data)
+        self.protocol.resetTimeout()
+
+    def writeSequence(self, sequence):
+        self.original.writeSequence(sequence)
+        self.protocol.resetTimeout()
+
+    def __getattr__(self, name):
+        return getattr(self.original, name)
+
+
+def TimeoutWrapper(protocolClass):
+    """Return a subclass of protocolClass that drops idle connections.
+
+    The period comes from the instance's timeoutPeriod if that is set to a
+    true value, otherwise from the factory's timeoutPeriod.
+    """
+    class _TimeoutProtocol(protocolClass, object):
+        """Protocol that automatically disconnects when the connection is idle.
+        
+        Stability: Unstable
+        """
+
+        # Class-level default so cancelTimeout/resetTimeout are safe even if
+        # connectionMade never ran (e.g. a test calling connectionLost).
+        timeoutCall = None
+
+        def connectionMade(self):
+            self.timeoutCall = None
+            if getattr(self, 'timeoutPeriod', None):
+                timeoutPeriod = self.timeoutPeriod
+            else:
+                timeoutPeriod = self.factory.timeoutPeriod
+            self.setTimeout(timeoutPeriod)
+            # Chain up so the wrapped protocol's own connectionMade still runs.
+            protocolClass.connectionMade(self)
+
+        def makeConnection(self, transport):
+            t = _TransportWrapper(transport, self)
+            protocolClass.makeConnection(self, t)
+
+        def setTimeout(self, timeoutPeriod=None):
+            """Set a timeout.
+            
+            This will cancel any existing timeouts.
+
+            @param timeoutPeriod: If not C{None}, change the timeout period.
+                Otherwise, use the existing value.
+            """
+            self.cancelTimeout()
+            if timeoutPeriod is not None:
+                self.timeoutPeriod = timeoutPeriod
+            self.timeoutCall = reactor.callLater(self.timeoutPeriod, self.timeoutFunc)
+
+        def cancelTimeout(self):
+            """Cancel the timeout.
+            
+            If the timeout was already cancelled, this does nothing.
+            """
+            if self.timeoutCall:
+                try:
+                    self.timeoutCall.cancel()
+                except error.AlreadyCalled:
+                    pass
+                self.timeoutCall = None
+        
+        def resetTimeout(self):
+            """Reset the timeout, usually because some activity just happened.
+
+            This does nothing if no timeout is pending; use setTimeout to
+            start a new one.
+            """
+            if self.timeoutCall:
+                self.timeoutCall.reset(self.timeoutPeriod)
+
+        def dataReceived(self, data):
+            self.resetTimeout()
+            protocolClass.dataReceived(self, data)
+
+        def connectionLost(self, reason):
+            self.cancelTimeout()
+            protocolClass.connectionLost(self, reason)
+
+        def timeoutFunc(self):
+            """This method is called when the timeout is triggered.
+
+            By default it calls C{self.transport.loseConnection()}.  Override
+            this if you want something else to happen.
+            """
+            # The call has fired; forget it so resetTimeout() (e.g. from a
+            # write while closing) doesn't try to reset a call that already ran.
+            self.timeoutCall = None
+            self.transport.loseConnection()
+    return _TimeoutProtocol
+            
 
 class TimeoutProtocol(ProtocolWrapper):
     """Protocol that automatically disconnects when the connection is idle.
Index: twisted/protocols/http.py
===================================================================
RCS file: /cvs/Twisted/twisted/protocols/http.py,v
retrieving revision 1.79
diff -u -r1.79 http.py
--- twisted/protocols/http.py	8 May 2003 16:28:09 -0000	1.79
+++ twisted/protocols/http.py	11 May 2003 09:07:39 -0000
@@ -47,6 +47,7 @@
 
 # twisted imports
 from twisted.internet import interfaces, reactor, protocol
+from twisted.protocols import policies
 from twisted.python import log
 
 
@@ -856,7 +857,7 @@
         pass
 
 
-class HTTPChannel(basic.LineReceiver):
+class HTTPChannel(policies.TimeoutWrapper(basic.LineReceiver)):
     """A receiver for HTTP requests."""
 
     length = 0
@@ -868,9 +869,9 @@
     # set in instances or subclasses
     requestFactory = Request
 
-
     def __init__(self):
+        super(HTTPChannel, self).__init__()
         # the request queue
         self.requests = []
 
     def lineReceived(self, line):
@@ -944,5 +945,9 @@
 
         req = self.requests[-1]
+        # Stop the idle timer *before* processing: a synchronously rendered
+        # request calls requestDone() from inside requestReceived(), which
+        # restarts the timer, and cancelling afterwards would kill it.
+        self.cancelTimeout()
         req.requestReceived(command, path, version)
 
     def rawDataReceived(self, data):
@@ -1000,10 +1005,15 @@
             # notify next request it can start writing
             if self.requests:
                 self.requests[0].noLongerQueued()
+            else:
+                # Connection is now idle: start a fresh timer (the old one was
+                # cancelled when the request arrived, so a reset is a no-op).
+                self.setTimeout()
         else:
             self.transport.loseConnection()
 
     def connectionLost(self, reason):
+        super(HTTPChannel, self).connectionLost(reason)
         for request in self.requests:
             request.connectionLost(reason)
 
@@ -1014,6 +1024,8 @@
     protocol = HTTPChannel
 
     logPath = None
+
+    timeoutPeriod = 60
 
     def __init__(self, logPath=None):
         if logPath is not None:
Index: twisted/test/test_web.py
===================================================================
RCS file: /cvs/Twisted/twisted/test/test_web.py,v
retrieving revision 1.31
diff -u -r1.31 test_web.py
--- twisted/test/test_web.py	10 May 2003 08:30:04 -0000	1.31
+++ twisted/test/test_web.py	11 May 2003 09:07:39 -0000
@@ -201,6 +201,9 @@
                   "Accept: text/html"]:
             self.channel.lineReceived(l)
 
+    def tearDown(self):
+        self.channel.connectionLost(reason=None)
+
     def test_modified(self):
         """If-Modified-Since cache validator (positive)"""
         self.channel.lineReceived("If-Modified-Since: %s"


**** Moshe's patch: ****


Index: twisted/protocols/http.py
===================================================================
RCS file: /cvs/Twisted/twisted/protocols/http.py,v
retrieving revision 1.78
diff -u -r1.78 http.py
--- twisted/protocols/http.py	21 Apr 2003 02:51:52 -0000	1.78
+++ twisted/protocols/http.py	11 May 2003 06:29:26 -0000
@@ -850,6 +850,7 @@
     __header = ''
     __first_line = 1
     __content = None
+    terminateConnection = None
 
     # set in instances or subclasses
     requestFactory = Request
@@ -859,7 +860,16 @@
         # the request queue
         self.requests = []
 
+    def connectionMade(self):
+        # TODO(review): connectionLost should also cancel terminateConnection;
+        # otherwise the pending call outlives the connection.
+        self.timeout = self.factory.timeout
+        self.terminateConnection = reactor.callLater(self.timeout,
+                                                  self.transport.loseConnection)
+
     def lineReceived(self, line):
+        if self.terminateConnection:
+            self.terminateConnection.reset(self.timeout)
         if self.__first_line:
             # if this connection is not persistent, drop any data which
             # the client (illegally) sent after the last request.
@@ -930,7 +940,14 @@
 
         req = self.requests[-1]
+        # Cancel before processing: a synchronously rendered request calls
+        # requestDone() from inside requestReceived(), which arms a new timer.
+        if self.terminateConnection:
+            self.terminateConnection.cancel()
+            self.terminateConnection = None
         req.requestReceived(command, path, version)
 
     def rawDataReceived(self, data):
+        if self.terminateConnection:
+            self.terminateConnection.reset(self.timeout)
         if len(data) < self.length:
             self.requests[-1].handleContentChunk(data)
@@ -986,6 +1003,9 @@
             # notify next request it can start writing
             if self.requests:
                 self.requests[0].noLongerQueued()
+            else:
+                self.terminateConnection = reactor.callLater(self.timeout,
+                                                  self.transport.loseConnection)
         else:
             self.transport.loseConnection()
 
@@ -998,6 +1018,7 @@
     """Factory for HTTP server."""
 
     protocol = HTTPChannel
+    timeout = 60
 
     logPath = None
 





More information about the Twisted-Python mailing list