| 1 | from twisted.internet import iocpreactor |
|---|
| 2 | iocpreactor.install() |
|---|
| 3 | |
|---|
| 4 | from twisted.internet import protocol,defer,reactor |
|---|
| 5 | from twisted.protocols import policies |
|---|
| 6 | from twisted.trial import unittest |
|---|
| 7 | import time |
|---|
| 8 | |
|---|
| 9 | |
|---|
# Traffic parameters for the test.
# The client sends maxWrites packets back to back; each packet is
# packet_size repetitions of smallChunk.
packet_size = 130000
maxWrites = 10

# Payload unit; the server also echoes exactly one of these per packet.
smallChunk = b'X'
smallLen = len(smallChunk)

# Number of bytes that make up one complete packet on the wire.
totalLen = packet_size * smallLen


class FireOnClose(policies.ProtocolWrapper):
    """A wrapper around a protocol that fires its factory's ``deferred`` when
    connectionLost is called.

    ProtocolWrapper.connectionLost (which forwards to the wrapped protocol)
    runs first, so state recorded by the wrapped protocol is visible to
    callbacks attached to the deferred.
    """

    def connectionLost(self, reason):
        policies.ProtocolWrapper.connectionLost(self, reason)
        # A Deferred can only be fired once. If this factory produced more
        # than one connection (e.g. a stray client hitting the listening
        # port), firing again would raise AlreadyCalledError.
        if not self.factory.deferred.called:
            self.factory.deferred.callback(None)


class FireOnCloseFactory(policies.WrappingFactory):
    """A WrappingFactory whose ``deferred`` fires when a wrapped connection
    is closed.

    When used as a client factory, ``deferred`` is errbacked if the
    connection attempt itself fails. A test waiting on it then fails
    promptly instead of hanging until the trial timeout.
    """

    protocol = FireOnClose

    def __init__(self, wrappedFactory):
        policies.WrappingFactory.__init__(self, wrappedFactory)
        # Fired by FireOnClose.connectionLost, or errbacked by
        # clientConnectionFailed below.
        self.deferred = defer.Deferred()

    def clientConnectionFailed(self, connector, reason):
        policies.WrappingFactory.clientConnectionFailed(self, connector, reason)
        # No connection was made, so connectionLost will never fire the
        # deferred; report the failure through it instead.
        if not self.deferred.called:
            self.deferred.errback(reason)


class SequenceServerProtocol(protocol.Protocol):
    """Server side: count incoming bytes and acknowledge each full packet.

    The running byte count lives on the factory (``factory.dataLen``).
    Every time ``totalLen`` bytes have been received, one ``smallChunk`` is
    written back as the acknowledgement for that packet.
    """

    def dataReceived(self, data):
        self.factory.dataLen += len(data)
        # Each packet is totalLen bytes; send back a single chunk per packet.
        # This is a loop rather than an if because one read may complete
        # several packets.
        while self.factory.dataLen >= totalLen:
            print('SequenceServerProtocol::dataReceived Echoing back')
            self.transport.write(smallChunk)
            self.factory.dataLen -= totalLen

    def connectionLost(self, reason):
        print('Connection lost')
        self.factory.done = 1


class SequenceServerFactory(protocol.Factory):
    """Builds SequenceServerProtocol instances that share this factory's state.

    ``dataLen`` holds the count of received bytes not yet acknowledged. It is
    kept on the factory so the protocols this factory builds update one
    shared counter.
    """

    # Bytes received toward the next, not yet acknowledged, packet.
    dataLen = 0

    def buildProtocol(self, addr):
        server = SequenceServerProtocol()
        server.factory = self
        return server


class SequenceClientProtocol(protocol.Protocol):
    """Client side: send maxWrites packets, then process 1-byte acknowledgements.

    Each acknowledgement is expected to arrive in its own dataReceived call.
    If a single call delivers more than one byte, several acknowledgements
    arrived before the earlier ones were processed. The run is then marked
    abnormal on the factory and the connection is dropped.
    """

    def connectionMade(self):
        print('SequenceClientProtocol::connectionMade Writing')
        self.writeData()

    def dataReceived(self, data):
        print('SequenceClientProtocol::dataReceived Reading %s echo_count %s len %s' % (len(data), self.factory.echo_count, len(data)))
        if len(data) > 1:
            # More than one byte means more than one response was received
            # without the previous gotResponse callback having run.
            # Record the failure before asking for the disconnect.
            self.factory.normal = False
            self.factory.error_value = len(data)
            self.transport.loseConnection()
            return
        self.factory.echo_count += 1
        # Process the response on a later reactor iteration. The test expects
        # this delayed call to run before the next acknowledgement is read.
        reactor.callLater(0, self.gotResponse)

    def gotResponse(self):
        print('Got a response count %s' % self.factory.echo_count)
        if self.factory.echo_count >= maxWrites:
            # Every packet has been acknowledged one read at a time.
            self.factory.normal = True
            self.transport.loseConnection()

    def writeData(self):
        print('SequenceClientProtocol::writeData Write %d data blocks of size %d' % (maxWrites, totalLen))
        # Build the packet once; bytes are immutable, so the same object can
        # be handed to transport.write repeatedly.
        packet = smallChunk * packet_size
        for _i in range(maxWrites):
            self.transport.write(packet)

    def connectionLost(self, reason):
        print('Connection lost client')
        self.factory.done = 1


class SequenceClientFactory(protocol.ClientFactory):
    """Client factory that also stores the state SequenceClientProtocol records.

    Attributes initialised here and updated by SequenceClientProtocol:
      normal      -- set True once maxWrites acknowledgements were processed
      error_value -- length of a read that carried more than one byte
      echo_count  -- number of single-byte acknowledgements received
      done        -- set to 1 when the client connection is lost
    """

    def __init__(self):
        self.normal = False
        self.error_value = 0
        self.echo_count = 0
        self.done = 0

    def buildProtocol(self, addr):
        client = SequenceClientProtocol()
        client.factory = self
        return client


class SequenceTestCase(unittest.TestCase):
    """Test that the reactor correctly interleaves I/O with callLater calls.

    The client fires maxWrites packets of totalLen bytes each (10 packets of
    130000 bytes with the module defaults) at the server. The server replies
    to each packet with a 1-byte acknowledgement. The client's gotResponse
    method should be called once for each acknowledgement.
    """

    def testWriter(self):
        print('Starting')
        f = SequenceServerFactory()
        f.done = 0
        wrappedF = FireOnCloseFactory(f)
        p = reactor.listenTCP(0, wrappedF, interface="127.0.0.1")
        self.addCleanup(p.stopListening)
        n = p.getHost().port
        clientF = SequenceClientFactory()
        wrappedClientF = FireOnCloseFactory(clientF)
        reactor.connectTCP("127.0.0.1", n, wrappedClientF)

        # Wait for BOTH ends to close. If the test finished as soon as the
        # client side closed, the server-side transport could still be
        # registered with the reactor, and trial would report a dirty reactor.
        d = defer.gatherResults([wrappedF.deferred, wrappedClientF.deferred])

        def check(ignored):
            self.assertTrue(clientF.normal,
                "client received sequence is abnormal %d responses received without processing" % clientF.error_value)
            self.assertTrue(clientF.done,
                "client didn't see connection dropped")
        return d.addCallback(check)