root / trunk / twisted / protocols / policies.py

Revision 26919, 18.6 kB (checked in by exarkun, 2 months ago)

Merge iocp-ssl-593-2

Author: exarkun
Reviewer: glyph
Fixes: #593

Add a ProtocolWrapper-based implementation of TLS and use it
to extend the IO Completion Ports reactor with implementations
of IReactorSSL and ITLSTransport.

One of the existing SSL test suites is also changed to remove
one of its base classes (and therefore some of the tests it
was running); these tests had nothing to do with SSL, they were
just redundant TCP tests.

Line 
1 # -*- test-case-name: twisted.test.test_policies -*-
2 # Copyright (c) 2001-2009 Twisted Matrix Laboratories.
3 # See LICENSE for details.
4
5 """
6 Resource limiting policies.
7
8 @seealso: See also L{twisted.protocols.htb} for rate limiting.
9 """
10
11 # system imports
12 import sys, operator
13
14 from zope.interface import directlyProvides, providedBy
15
16 # twisted imports
17 from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
18 from twisted.internet import error
19 from twisted.python import log
20
21
class ProtocolWrapper(Protocol):
    """
    Wraps protocol instances and acts as their transport as well.

    Toward the real transport this object plays the protocol role; toward
    C{wrappedProtocol} it plays the transport role.  Transport calls made by
    the wrapped protocol are relayed to C{self.transport}, protocol events
    from the real transport are relayed to C{wrappedProtocol}, and any
    attribute not defined here is looked up on C{self.transport}.

    @ivar wrappedProtocol: An L{IProtocol} provider to which L{IProtocol}
        method calls onto this L{ProtocolWrapper} will be proxied.

    @ivar factory: The L{WrappingFactory} which created this
        L{ProtocolWrapper}.
    """

    # Flipped to 1 by loseConnection().
    disconnecting = 0

    def __init__(self, factory, wrappedProtocol):
        self.factory = factory
        self.wrappedProtocol = wrappedProtocol

    def makeConnection(self, transport):
        """
        Attach to C{transport}, register with the factory, then connect the
        wrapped protocol to this wrapper so that it sends its transport calls
        through us.
        """
        # Advertise every interface the real transport provides, so the
        # wrapped protocol can use this wrapper in its place.
        directlyProvides(self, providedBy(transport))
        Protocol.makeConnection(self, transport)
        self.factory.registerProtocol(self)
        self.wrappedProtocol.makeConnection(self)

    # --- Transport side: forwarded to the real transport ---

    def write(self, data):
        """Send C{data} through the real transport."""
        self.transport.write(data)

    def writeSequence(self, data):
        """Send a sequence of strings through the real transport."""
        self.transport.writeSequence(data)

    def loseConnection(self):
        """Mark this wrapper as disconnecting and close the real transport."""
        self.disconnecting = 1
        self.transport.loseConnection()

    def getPeer(self):
        """Return the real transport's peer address."""
        return self.transport.getPeer()

    def getHost(self):
        """Return the real transport's host address."""
        return self.transport.getHost()

    def registerProducer(self, producer, streaming):
        """Register C{producer} with the real transport."""
        self.transport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        """Unregister the producer from the real transport."""
        self.transport.unregisterProducer()

    def stopConsuming(self):
        """Forward C{stopConsuming} to the real transport."""
        self.transport.stopConsuming()

    def __getattr__(self, name):
        # Only consulted when normal lookup fails: fall back to the real
        # transport for anything this wrapper does not define itself.
        return getattr(self.transport, name)

    # --- Protocol side: forwarded to the wrapped protocol ---

    def dataReceived(self, data):
        """Hand received bytes to the wrapped protocol."""
        self.wrappedProtocol.dataReceived(data)

    def connectionLost(self, reason):
        """Unregister from the factory, then notify the wrapped protocol."""
        self.factory.unregisterProtocol(self)
        self.wrappedProtocol.connectionLost(reason)
88
89
class WrappingFactory(ClientFactory):
    """
    Wraps a factory and its protocols, and keeps track of them.

    Every protocol built by C{wrappedFactory} is wrapped in an instance of
    C{protocol}.  Lifecycle and client-connection callbacks are forwarded to
    C{wrappedFactory}.
    """

    protocol = ProtocolWrapper

    def __init__(self, wrappedFactory):
        self.wrappedFactory = wrappedFactory
        # Used as a set: keys are the live wrappers, values are always 1.
        self.protocols = {}

    def doStart(self):
        """Start the wrapped factory first, then this one."""
        self.wrappedFactory.doStart()
        ClientFactory.doStart(self)

    def doStop(self):
        """Stop the wrapped factory first, then this one."""
        self.wrappedFactory.doStop()
        ClientFactory.doStop(self)

    def startedConnecting(self, connector):
        """Forwarded to the wrapped factory."""
        self.wrappedFactory.startedConnecting(connector)

    def clientConnectionFailed(self, connector, reason):
        """Forwarded to the wrapped factory."""
        self.wrappedFactory.clientConnectionFailed(connector, reason)

    def clientConnectionLost(self, connector, reason):
        """Forwarded to the wrapped factory."""
        self.wrappedFactory.clientConnectionLost(connector, reason)

    def buildProtocol(self, addr):
        """Build the wrapped factory's protocol and wrap it in C{protocol}."""
        inner = self.wrappedFactory.buildProtocol(addr)
        return self.protocol(self, inner)

    def registerProtocol(self, p):
        """Called by protocol to register itself."""
        self.protocols[p] = 1

    def unregisterProtocol(self, p):
        """Called by protocols when they go away."""
        del self.protocols[p]
126
127
class ThrottlingProtocol(ProtocolWrapper):
    """
    Protocol for ThrottlingFactory.

    Reports every byte read and written to the factory.  Lets the factory
    pause and resume reading through the transport, and writing through the
    registered producer (if any).
    """

    # wrap API for tracking bandwidth

    def write(self, data):
        self.factory.registerWritten(len(data))
        ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        # Materialize first so an iterator is not exhausted by the length
        # computation before it is forwarded.  sum() yields 0 for an empty
        # sequence, where reduce() without an initial value raises TypeError.
        seq = list(seq)
        self.factory.registerWritten(sum([len(chunk) for chunk in seq]))
        ProtocolWrapper.writeSequence(self, seq)

    def dataReceived(self, data):
        self.factory.registerRead(len(data))
        ProtocolWrapper.dataReceived(self, data)

    def registerProducer(self, producer, streaming):
        # Remember the producer so throttleWrites/unthrottleWrites can
        # pause/resume it.
        self.producer = producer
        ProtocolWrapper.registerProducer(self, producer, streaming)

    def unregisterProducer(self):
        del self.producer
        ProtocolWrapper.unregisterProducer(self)

    def throttleReads(self):
        """Stop reading from the transport."""
        self.transport.pauseProducing()

    def unthrottleReads(self):
        """Resume reading from the transport."""
        self.transport.resumeProducing()

    def throttleWrites(self):
        """Pause the registered producer, if there is one."""
        if hasattr(self, "producer"):
            self.producer.pauseProducing()

    def unthrottleWrites(self):
        """Resume the registered producer, if there is one."""
        if hasattr(self, "producer"):
            self.producer.resumeProducing()
167
168
class ThrottlingFactory(WrappingFactory):
    """
    Throttles bandwidth and number of connections.

    Write bandwidth will only be throttled if there is a producer
    registered.

    Bandwidth is sampled once per second.  If more than C{readLimit}
    (C{writeLimit}) bytes were read (written) during the last second, reads
    (writes) are paused on every connection.  They resume after a delay
    proportional to the overshoot.
    """

    protocol = ThrottlingProtocol

    def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint,
                 readLimit=None, writeLimit=None):
        WrappingFactory.__init__(self, wrappedFactory)
        self.connectionCount = 0
        self.maxConnectionCount = maxConnectionCount
        self.readLimit = readLimit # max bytes we should read per second
        self.writeLimit = writeLimit # max bytes we should write per second
        self.readThisSecond = 0
        self.writtenThisSecond = 0
        # Pending delayed calls (or None), cancelled when the last
        # connection goes away.
        self.unthrottleReadsID = None
        self.checkReadBandwidthID = None
        self.unthrottleWritesID = None
        self.checkWriteBandwidthID = None


    def callLater(self, period, func):
        """
        Wrapper around L{reactor.callLater} for test purpose.
        """
        from twisted.internet import reactor
        return reactor.callLater(period, func)


    def registerWritten(self, length):
        """
        Called by protocol to tell us more bytes were written.
        """
        self.writtenThisSecond += length


    def registerRead(self, length):
        """
        Called by protocol to tell us more bytes were read.
        """
        self.readThisSecond += length


    def checkReadBandwidth(self):
        """
        Checks if we've passed bandwidth limits.

        Throttles reads if the last second's total exceeded C{readLimit},
        then resets the counter and reschedules itself one second later.
        """
        if self.readThisSecond > self.readLimit:
            self.throttleReads()
            throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
            self.unthrottleReadsID = self.callLater(throttleTime,
                                                    self.unthrottleReads)
        self.readThisSecond = 0
        self.checkReadBandwidthID = self.callLater(1, self.checkReadBandwidth)


    def checkWriteBandwidth(self):
        """
        Write-side counterpart of L{checkReadBandwidth}, using C{writeLimit}.
        """
        if self.writtenThisSecond > self.writeLimit:
            self.throttleWrites()
            throttleTime = (float(self.writtenThisSecond) / self.writeLimit) - 1.0
            self.unthrottleWritesID = self.callLater(throttleTime,
                                                     self.unthrottleWrites)
        # reset for next round
        self.writtenThisSecond = 0
        self.checkWriteBandwidthID = self.callLater(1, self.checkWriteBandwidth)


    def throttleReads(self):
        """
        Throttle reads on all protocols.
        """
        log.msg("Throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.throttleReads()


    def unthrottleReads(self):
        """
        Stop throttling reads on all protocols.
        """
        self.unthrottleReadsID = None
        log.msg("Stopped throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.unthrottleReads()


    def throttleWrites(self):
        """
        Throttle writes on all protocols.
        """
        log.msg("Throttling writes on %s" % self)
        for p in self.protocols.keys():
            p.throttleWrites()


    def unthrottleWrites(self):
        """
        Stop throttling writes on all protocols.
        """
        self.unthrottleWritesID = None
        log.msg("Stopped throttling writes on %s" % self)
        for p in self.protocols.keys():
            p.unthrottleWrites()


    def buildProtocol(self, addr):
        """
        Build a throttled protocol, or return C{None} once
        C{maxConnectionCount} connections are open.

        The periodic bandwidth checks start when the first connection
        arrives.
        """
        if self.connectionCount == 0:
            if self.readLimit is not None:
                self.checkReadBandwidth()
            if self.writeLimit is not None:
                self.checkWriteBandwidth()

        if self.connectionCount < self.maxConnectionCount:
            self.connectionCount += 1
            return WrappingFactory.buildProtocol(self, addr)
        else:
            log.msg("Max connection count reached!")
            return None


    def unregisterProtocol(self, p):
        """
        Forget C{p}.  When the last connection goes away, stop every pending
        bandwidth check and unthrottle call.
        """
        WrappingFactory.unregisterProtocol(self, p)
        self.connectionCount -= 1
        if self.connectionCount == 0:
            for name in ('unthrottleReadsID', 'checkReadBandwidthID',
                         'unthrottleWritesID', 'checkWriteBandwidthID'):
                delayedCall = getattr(self, name)
                # Only cancel calls that are still pending, and always drop
                # the reference.  A stale, already-cancelled call left here
                # would be cancelled again on the next drain, which raises
                # AlreadyCancelled.
                if delayedCall is not None and delayedCall.active():
                    delayedCall.cancel()
                setattr(self, name, None)
305
306
307
class SpewingProtocol(ProtocolWrapper):
    """
    Logs, via L{log.msg}, the data received from and sent to the transport.
    """

    def dataReceived(self, data):
        log.msg("Received: %r" % data)
        ProtocolWrapper.dataReceived(self, data)

    def write(self, data):
        log.msg("Sending: %r" % data)
        ProtocolWrapper.write(self, data)

    def writeSequence(self, data):
        # ProtocolWrapper.writeSequence goes straight to the transport without
        # passing through write(), so data sent this way would otherwise never
        # be logged.  Materialize first so an iterator isn't consumed twice.
        data = list(data)
        for chunk in data:
            log.msg("Sending: %r" % (chunk,))
        ProtocolWrapper.writeSequence(self, data)
316
317
318
class SpewingFactory(WrappingFactory):
    """
    L{WrappingFactory} that wraps each connection in a L{SpewingProtocol},
    which logs traffic through L{log.msg}.
    """

    protocol = SpewingProtocol
321
322
323
class LimitConnectionsByPeer(WrappingFactory):
    """
    Refuse a new connection when its peer host already has
    C{maxConnectionsPerPeer} connections open.

    @ivar peerConnections: maps peer host to its number of open
        connections; (re)created by L{startFactory}.
    """

    maxConnectionsPerPeer = 5

    def startFactory(self):
        self.peerConnections = {}

    def _peerHost(self, address):
        """
        Return the host part of C{address}.

        Used by both L{buildProtocol} and L{unregisterProtocol} so the two
        always key C{peerConnections} the same way.  Address objects expose
        the host as a C{host} attribute; a plain C{(host, port)} tuple is
        also accepted.
        """
        host = getattr(address, 'host', None)
        if host is None:
            host = address[0]
        return host

    def buildProtocol(self, addr):
        peerHost = self._peerHost(addr)
        connectionCount = self.peerConnections.get(peerHost, 0)
        if connectionCount >= self.maxConnectionsPerPeer:
            return None
        self.peerConnections[peerHost] = connectionCount + 1
        return WrappingFactory.buildProtocol(self, addr)

    def unregisterProtocol(self, p):
        # The inherited bookkeeping must run too, or self.protocols would
        # keep every wrapper that was ever registered.
        WrappingFactory.unregisterProtocol(self, p)
        peerHost = self._peerHost(p.getPeer())
        remaining = self.peerConnections.get(peerHost, 0) - 1
        if remaining > 0:
            self.peerConnections[peerHost] = remaining
        else:
            # Drop the entry once the peer has no connections left (also
            # tolerates a count lost to a factory restart).
            self.peerConnections.pop(peerHost, None)
344
345
class LimitTotalConnectionsFactory(ServerFactory):
    """
    Factory that limits the number of simultaneous connections.

    @type connectionCount: C{int}
    @ivar connectionCount: number of current connections.
    @type connectionLimit: C{int} or C{None}
    @cvar connectionLimit: maximum number of connections.
    @type overflowProtocol: L{Protocol} or C{None}
    @cvar overflowProtocol: Protocol to use for new connections when
        connectionLimit is exceeded.  If C{None} (the default value), excess
        connections will be closed immediately.
    """
    connectionCount = 0
    connectionLimit = None
    overflowProtocol = None

    def buildProtocol(self, addr):
        """
        Build C{protocol} while under the limit.  Past the limit, build
        C{overflowProtocol}, or return C{None} if there is none.  The result
        is wrapped in a L{ProtocolWrapper} and counted, including overflow
        connections.
        """
        belowLimit = (self.connectionLimit is None or
                      self.connectionCount < self.connectionLimit)
        if belowLimit:
            # Build the normal protocol
            protocolClass = self.protocol
        elif self.overflowProtocol is not None:
            # Too many connections, so build the overflow protocol
            protocolClass = self.overflowProtocol
        else:
            # Just drop the connection
            return None

        wrappedProtocol = protocolClass()
        wrappedProtocol.factory = self
        wrapper = ProtocolWrapper(self, wrappedProtocol)
        self.connectionCount += 1
        return wrapper

    def registerProtocol(self, p):
        """Nothing to do: counting already happened in buildProtocol."""

    def unregisterProtocol(self, p):
        """A connection went away; release its slot."""
        self.connectionCount -= 1
385
386
387
class TimeoutProtocol(ProtocolWrapper):
    """
    Protocol that automatically disconnects when the connection is idle.

    Each write, writeSequence or dataReceived restarts the countdown; if it
    runs out, L{timeoutFunc} is called.
    """

    def __init__(self, factory, wrappedProtocol, timeoutPeriod):
        """
        Constructor.

        @param factory: An L{IFactory}.  Its C{callLater} method is used to
            schedule the timeout.
        @param wrappedProtocol: A L{Protocol} to wrap.
        @param timeoutPeriod: Number of seconds to wait for activity before
            timing out.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.timeoutCall = None
        self.setTimeout(timeoutPeriod)


    def setTimeout(self, timeoutPeriod=None):
        """
        Set a timeout.

        This will cancel any existing timeouts.

        @param timeoutPeriod: If not C{None}, change the timeout period.
            Otherwise, use the existing value.
        """
        self.cancelTimeout()
        if timeoutPeriod is not None:
            self.timeoutPeriod = timeoutPeriod
        self.timeoutCall = self.factory.callLater(self.timeoutPeriod,
                                                  self.timeoutFunc)


    def cancelTimeout(self):
        """
        Cancel the timeout.

        If the timeout was already cancelled, or has already fired, this does
        nothing.
        """
        if self.timeoutCall:
            try:
                self.timeoutCall.cancel()
            except (error.AlreadyCalled, error.AlreadyCancelled):
                # Nothing left to cancel.
                pass
            self.timeoutCall = None


    def resetTimeout(self):
        """
        Reset the timeout, usually because some activity just happened.
        """
        # Once the timeout has fired, the delayed call can no longer be reset
        # (reset() raises AlreadyCalled).  Activity can still arrive after
        # that, for example a write or incoming data before the
        # timeout-triggered close completes, so only reset a pending call.
        if self.timeoutCall and self.timeoutCall.active():
            self.timeoutCall.reset(self.timeoutPeriod)


    def write(self, data):
        self.resetTimeout()
        ProtocolWrapper.write(self, data)


    def writeSequence(self, seq):
        self.resetTimeout()
        ProtocolWrapper.writeSequence(self, seq)


    def dataReceived(self, data):
        self.resetTimeout()
        ProtocolWrapper.dataReceived(self, data)


    def connectionLost(self, reason):
        self.cancelTimeout()
        ProtocolWrapper.connectionLost(self, reason)


    def timeoutFunc(self):
        """
        This method is called when the timeout is triggered.

        By default it calls L{loseConnection}.  Override this if you want
        something else to happen.
        """
        self.loseConnection()
452
453
454     def dataReceived(self, data):
455         self.resetTimeout()
456         ProtocolWrapper.dataReceived(self, data)
457
458
459     def connectionLost(self, reason):
460         self.cancelTimeout()
461         ProtocolWrapper.connectionLost(self, reason)
462
463
464     def timeoutFunc(self):
465         """
466         This method is called when the timeout is triggered.
467
468         By default it calls L{loseConnection}.  Override this if you want
469         something else to happen.
470         """
471         self.loseConnection()
472
473
474
475 class TimeoutFactory(WrappingFactory):
476     """
477     Factory for TimeoutWrapper.
478     """
479     protocol = TimeoutProtocol
480
481
482     def __init__(self, wrappedFactory, timeoutPeriod=30*60):
483         self.timeoutPeriod = timeoutPeriod
484         WrappingFactory.__init__(self, wrappedFactory)
485
486
487     def buildProtocol(self, addr):
488         return self.protocol(self, self.wrappedFactory.buildProtocol(addr),
489                              timeoutPeriod=self.timeoutPeriod)
490
491
492     def callLater(self, period, func):
493         """
494         Wrapper around L{reactor.callLater} for test purpose.
495         """
496         from twisted.internet import reactor
497         return reactor.callLater(period, func)
498
499
500
501 class TrafficLoggingProtocol(ProtocolWrapper):
502
503     def __init__(self, factory, wrappedProtocol, logfile, lengthLimit=None,
504                  number=0):
505         """
506         @param factory: factory which created this protocol.
507         @type factory: C{protocol.Factory}.
508         @param wrappedProtocol: the underlying protocol.
509         @type wrappedProtocol: C{protocol.Protocol}.
510         @param logfile: file opened for writing used to write log messages.
511         @type logfile: C{file}
512         @param lengthLimit: maximum size of the datareceived logged.
513         @type lengthLimit: C{int}
514         @param number: identifier of the connection.
515         @type number: C{int}.
516         """
517         ProtocolWrapper.__init__(self, factory, wrappedProtocol)
518         self.logfile = logfile
519         self.lengthLimit = lengthLimit
520         self._number = number
521
522
523     def _log(self, line):
524         self.logfile.write(line + '\n')
525         self.logfile.flush()
526
527
528     def _mungeData(self, data):
529         if self.lengthLimit and len(data) > self.lengthLimit:
530             data = data[:self.lengthLimit - 12] + '<... elided>'
531         return data
532
533
534     # IProtocol
535     def connectionMade(self):
536         self._log('*')
537         return ProtocolWrapper.connectionMade(self)
538
539
540     def dataReceived(self, data):
541         self._log('C %d: %r' % (self._number, self._mungeData(data)))
542         return ProtocolWrapper.dataReceived(self, data)
543
544
545     def connectionLost(self, reason):
546         self._log('C %d: %r' % (self._number, reason))
547         return ProtocolWrapper.connectionLost(self, reason)
548
549
550     # ITransport
551     def write(self, data):
552         self._log('S %d: %r' % (self._number, self._mungeData(data)))
553         return ProtocolWrapper.write(self, data)
554
555
556     def writeSequence(self, iovec):
557         self._log('SV %d: %r' % (self._number, [self._mungeData(d) for d in iovec]))
558         return ProtocolWrapper.writeSequence(self, iovec)
559
560
561     def loseConnection(self):
562         self._log('S %d: *' % (self._number,))
563         return ProtocolWrapper.loseConnection(self)
564
565
566
567 class TrafficLoggingFactory(WrappingFactory):
568     protocol = TrafficLoggingProtocol
569
570     _counter = 0
571
572     def __init__(self, wrappedFactory, logfilePrefix, lengthLimit=None):
573         self.logfilePrefix = logfilePrefix
574         self.lengthLimit = lengthLimit
575         WrappingFactory.__init__(self, wrappedFactory)
576
577
578     def open(self, name):
579         return file(name, 'w')
580
581
582     def buildProtocol(self, addr):
583         self._counter += 1
584         logfile = self.open(self.logfilePrefix + '-' + str(self._counter))
585         return self.protocol(self, self.wrappedFactory.buildProtocol(addr),
586                              logfile, self.lengthLimit, self._counter)
587
588
589     def resetCounter(self):
590         """
591         Reset the value of the counter used to identify connections.
592         """
593         self._counter = 0
594
595
596
597 class TimeoutMixin:
598     """Mixin for protocols which wish to timeout connections
599
600     @cvar timeOut: The number of seconds after which to timeout the connection.
601     """
602     timeOut = None
603
604     __timeoutCall = None
605
606     def callLater(self, period, func):
607         from twisted.internet import reactor
608         return reactor.callLater(period, func)
609
610
611     def resetTimeout(self):
612         """Reset the timeout count down"""
613         if self.__timeoutCall is not None and self.timeOut is not None:
614             self.__timeoutCall.reset(self.timeOut)
615
616     def setTimeout(self, period):
617         """Change the timeout period
618
619         @type period: C{int} or C{NoneType}
620         @param period: The period, in seconds, to change the timeout to, or
621         C{None} to disable the timeout.
622         """
623         prev = self.timeOut
624         self.timeOut = period
625
626         if self.__timeoutCall is not None:
627             if period is None:
628                 self.__timeoutCall.cancel()
629                 self.__timeoutCall = None
630             else:
631                 self.__timeoutCall.reset(period)
632         elif period is not None:
633             self.__timeoutCall = self.callLater(period, self.__timedOut)
634
635         return prev
636
637     def __timedOut(self):
638         self.__timeoutCall = None
639         self.timeoutConnection()
640
641     def timeoutConnection(self):
642         """Called when the connection times out.
643         Override to define behavior other than dropping the connection.
644         """
645         self.transport.loseConnection()
Note: See TracBrowser for help on using the browser.