# Ticket #6217: test_iocp_sequence.py
#
# File test_iocp_sequence.py, 4.5 KB (added by BrianMatthews, 2 years ago)
# The IOCP reactor has to be installed before anything imports
# twisted.internet.reactor, so this stays ahead of the reactor import below.
from twisted.internet import iocpreactor
iocpreactor.install()

from twisted.internet import protocol
from twisted.internet import defer
from twisted.internet import reactor
from twisted.protocols import policies
from twisted.trial import unittest
import time


# Byte the client repeats to build a packet; the server also writes it back
# as its one-byte acknowledgement.
smallChunk = b'X'
smallLen = len(smallChunk)
# How many copies of smallChunk make up one client packet.
packet_size = 130000
# Size in bytes of one client packet; the server acks every totalLen bytes.
totalLen = smallLen * packet_size
# Number of packets the client sends (and acks it expects to see).
maxWrites = 10
15
class FireOnClose(policies.ProtocolWrapper):
    """Protocol wrapper that signals its factory when the connection closes.

    Once the base ``ProtocolWrapper.connectionLost`` handling has run, the
    factory's ``deferred`` is called back with ``None``.
    """

    def connectionLost(self, reason):
        policies.ProtocolWrapper.connectionLost(self, reason)
        # Fire only after the base wrapper has handled the loss.
        # NOTE(review): one deferred per factory, so this suits a factory
        # that serves a single connection; a second loss would call back an
        # already-fired deferred.
        closed = self.factory.deferred
        closed.callback(None)
23
24
class FireOnCloseFactory(policies.WrappingFactory):
    """WrappingFactory whose connections are wrapped in FireOnClose.

    Each instance owns a fresh ``deferred`` that FireOnClose calls back
    (with ``None``) when a wrapped connection is lost.
    """

    protocol = FireOnClose

    def __init__(self, wrappedFactory):
        policies.WrappingFactory.__init__(self, wrappedFactory)
        # Signalled by FireOnClose.connectionLost; tests wait on it.
        self.deferred = defer.Deferred()
31
32
33
class SequenceServerProtocol(protocol.Protocol):
    """Server side: acknowledge every complete packet with a single byte.

    Received bytes are counted on the factory (``factory.dataLen``). Each
    time at least ``totalLen`` bytes have accumulated, one ``smallChunk`` is
    written back and ``totalLen`` is subtracted from the running count.
    """

    def dataReceived(self, data):
        self.factory.dataLen += len(data)
        # One read may complete several packets; ack each one separately.
        while self.factory.dataLen >= totalLen:
            # Single-argument print() behaves the same on Python 2 and 3.
            print('SequenceServerProtocol::dataReceived Echoing back')
            self.transport.write(smallChunk)
            self.factory.dataLen -= totalLen

    def connectionLost(self, reason):
        print('Connection lost')
        # Recorded on the factory so the test can inspect it afterwards.
        self.factory.done = 1
47
class SequenceServerFactory(protocol.Factory):
    """Builds SequenceServerProtocol instances and holds their byte counter."""

    # Bytes received but not yet acknowledged. This is a class-level default;
    # the first ``+=`` through an instance creates the per-instance attribute.
    dataLen = 0

    def buildProtocol(self, addr):
        server = SequenceServerProtocol()
        server.factory = self
        return server
54
class SequenceClientProtocol(protocol.Protocol):
    """Client side: send maxWrites packets, then count the one-byte acks.

    Each ack is expected in its own dataReceived call, with the
    ``callLater``-scheduled gotResponse running between reads. A read that
    carries more than one byte is recorded on the factory as abnormal
    (``normal = False``, ``error_value = len(data)``) and the connection is
    dropped.
    """

    def connectionMade(self):
        print('SequenceClientProtocol::connectionMade Writing')
        self.writeData()

    def dataReceived(self, data):
        print('SequenceClientProtocol::dataReceived Reading %s echo_count %s len %s'
              % (len(data), self.factory.echo_count, len(data)))
        if len(data) > 1:
            # More than one ack in a single read: several responses arrived
            # without a gotResponse callback running in between. Record the
            # failure before closing so it is in place before any
            # disconnect handling sees the factory.
            self.factory.normal = False
            self.factory.error_value = len(data)
            self.transport.loseConnection()
            return
        self.factory.echo_count += 1
        # Zero-delay call: the test expects it to run on the next reactor
        # iteration, before the server's next ack is delivered here.
        reactor.callLater(0, self.gotResponse)

    def gotResponse(self):
        print('Got a response count %s' % self.factory.echo_count)
        if self.factory.echo_count >= maxWrites:
            # Every packet was acknowledged one ack per read.
            self.factory.normal = True
            self.transport.loseConnection()

    def writeData(self):
        print('SequenceClientProtocol::writeData Write %d data blocks of size %d'
              % (maxWrites, totalLen))
        # The packet is identical every time; build it once.
        packet = smallChunk * packet_size
        for _ in range(maxWrites):
            self.transport.write(packet)

    def connectionLost(self, reason):
        print('Connection lost client')
        self.factory.done = 1
86
class SequenceClientFactory(protocol.ClientFactory):
    """Builds SequenceClientProtocol instances and stores the test results.

    State initialised here and updated by the client protocol:
      done        -- set to 1 when the client connection is lost.
      echo_count  -- number of single-byte acks received so far.
      normal      -- True once maxWrites acks arrived one per read.
      error_value -- length of a read that carried more than one ack.
    """

    def __init__(self):
        self.done, self.echo_count = 0, 0
        self.normal = False
        self.error_value = 0

    def buildProtocol(self, addr):
        client = SequenceClientProtocol()
        client.factory = self
        return client
98
99
class SequenceTestCase(unittest.TestCase):
    """Test that the reactor correctly interleaves I/O and callLater calls.

    The client fires maxWrites (10) packets of totalLen (130000) bytes at the
    server, which replies to each packet with a 1-byte acknowledgement. The
    client's 'gotResponse' callLater should run once per ack, i.e. no single
    read may deliver more than one ack.
    """

    def testWriter(self):
        """Run the exchange and check both ends saw a clean sequence/close."""
        print('Starting')
        f = SequenceServerFactory()
        f.done = 0
        wrappedF = FireOnCloseFactory(f)
        p = reactor.listenTCP(0, wrappedF, interface="127.0.0.1")
        self.addCleanup(p.stopListening)
        n = p.getHost().port
        clientF = SequenceClientFactory()
        wrappedClientF = FireOnCloseFactory(clientF)
        reactor.connectTCP("127.0.0.1", n, wrappedClientF)

        # Wait for BOTH ends to close. Waiting on the client alone lets the
        # test finish while the accepted server connection may still be
        # open, which trial can report as a dirty reactor.
        d = defer.gatherResults([wrappedF.deferred, wrappedClientF.deferred])

        def check(ignored):
            self.assertTrue(
                clientF.normal,
                "client received sequence is abnormal %d responses received "
                "without processing (%d of %d acks handled)"
                % (clientF.error_value, clientF.echo_count, maxWrites))
            self.assertTrue(clientF.done,
                            "client didn't see connection dropped")
            self.assertTrue(f.done,
                            "server didn't see connection dropped")
        return d.addCallback(check)
127