Ticket #5562: 5562_test3.py

File 5562_test3.py, 3.4 KB (added by dynamicgl, 22 months ago)
Line 
1from twisted.internet import iocpreactor
2iocpreactor.install()
3from twisted.internet import protocol,defer,reactor
4from twisted.protocols import policies
5from twisted.trial import unittest
6
# Payload parameters for the large-buffer echo test: a one-byte chunk
# repeated ``ops`` times.
smallChunk = b'X'
ops = 2 * 1024 ** 2          # repetitions of smallChunk (2 MiB of 1-byte chunks)
smallLen = len(smallChunk)
totalLen = ops * smallLen    # bytes the client writes and expects to read back
11
class FireOnClose(policies.ProtocolWrapper):
    """Protocol wrapper that signals its factory when the connection closes.

    ``connectionLost`` first runs the normal ``ProtocolWrapper`` teardown and
    then calls back the wrapping factory's ``deferred`` with ``None``.
    A Deferred can only be fired once, so this assumes the factory builds a
    single connection whose loss is being awaited.
    """

    def connectionLost(self, reason):
        wrapperBase = policies.ProtocolWrapper
        wrapperBase.connectionLost(self, reason)
        closedSignal = self.factory.deferred
        closedSignal.callback(None)
19
20
class FireOnCloseFactory(policies.WrappingFactory):
    """WrappingFactory that wraps each protocol in ``FireOnClose``.

    Each factory instance owns a fresh ``deferred``, which the
    ``FireOnClose`` wrapper calls back when its connection is lost.
    """

    protocol = FireOnClose

    def __init__(self, wrappedFactory):
        factoryBase = policies.WrappingFactory
        factoryBase.__init__(self, wrappedFactory)
        self.deferred = defer.Deferred()
27
28
29
class LargeProtocol(protocol.Protocol):
    """Server-side echo protocol that tallies received bytes on its factory."""

    def dataReceived(self, data):
        # Echo the chunk back first, then add its size to the factory total.
        self.transport.write(data)
        self.factory.dataLen = self.factory.dataLen + len(data)

    def connectionLost(self, reason):
        # Record on the factory that this connection has closed.
        self.factory.done = 1
37
class LargeServerFactory(protocol.Factory):
    """Factory whose LargeProtocol echo servers share counters on it."""

    # Running byte count, incremented by LargeProtocol.dataReceived.
    dataLen = 0
    # Expected byte count; assigned by the test case, not read in this class.
    maxLen = 0

    def buildProtocol(self, addr):
        server = LargeProtocol()
        server.factory = self
        return server
45
class LargeClientProtocol(protocol.Protocol):
    """Client that sends ``totalLen`` bytes and checks the echoed byte count.

    One second after connecting, the whole payload (``smallChunk * ops``) is
    written in a single call. Received bytes are appended to the factory's
    ``dataBuffer``. One second after data arrives (while no check is pending),
    ``extraCheck`` compares the buffered length with ``totalLen``:

    - Equal: sets ``factory.normal = True`` and closes the connection.
    - Larger: sets ``factory.normal = False`` and closes the connection.
    - Smaller: waits for more data, which schedules another check.

    If fewer than ``totalLen`` bytes arrive and then data stops, no further
    check is scheduled and the test relies on trial's timeout.
    """

    # Pending DelayedCall for extraCheck, or None.
    checkd = None
    # Pending DelayedCall for the delayed payload write, or None.
    _writeCall = None

    def connectionMade(self):
        self.checkd = None
        self._writeCall = reactor.callLater(
            1, self.transport.write, smallChunk * ops)

    def dataReceived(self, data):
        self.factory.dataBuffer += data
        # Keep at most one check pending; once it has run, the next chunk of
        # data schedules a fresh one.
        if self.checkd is None:
            self.checkd = reactor.callLater(1, self.extraCheck)

    def connectionLost(self, reason):
        self.factory.done = 1
        # Cancel timers that have not fired yet so nothing runs against a
        # closed transport and no DelayedCall outlives the test.
        for call in (self.checkd, self._writeCall):
            if call is not None and call.active():
                call.cancel()
        self.checkd = None
        self._writeCall = None

    def extraCheck(self):
        # This DelayedCall is executing right now: forget it *before*
        # loseConnection(), so that if connectionLost runs synchronously it
        # does not try to cancel an already-called DelayedCall.
        self.checkd = None
        received = len(self.factory.dataBuffer)
        if received < totalLen:
            # Still short; the next dataReceived schedules another check.
            return
        self.factory.normal = (received == totalLen)
        self.transport.loseConnection()
74
75
class LargeClientFactory(protocol.ClientFactory):
    """Client factory holding the state that the test case inspects.

    Attributes:
        done: 0 until the client protocol's ``connectionLost`` sets it to 1.
        dataBuffer: every byte the client has received so far.
        normal: False until a check finds exactly ``totalLen`` bytes received.
    """

    def __init__(self):
        self.done = 0
        # Transports deliver ``bytes``. A text ``''`` buffer raises TypeError
        # on the first ``+=`` under Python 3. A bytearray also grows in place
        # instead of copying the whole buffer for every received chunk.
        self.dataBuffer = bytearray()
        self.normal = False

    def buildProtocol(self, addr):
        p = LargeClientProtocol()
        p.factory = self
        return p
86
87
class LargeTestCase(unittest.TestCase):
    """Test that buffering large amounts of data works.

    The client writes ``totalLen`` bytes to a local echo server. The test
    passes when exactly that many bytes come back and the client sees the
    connection close.
    """

    def testWriter(self):
        f = LargeServerFactory()
        f.done = 0
        f.maxLen = totalLen
        wrappedF = FireOnCloseFactory(f)
        # Port 0 asks the OS for a free ephemeral port, so the test cannot
        # collide with another listener on a fixed port. The port actually
        # bound is read back below.
        p = reactor.listenTCP(0, wrappedF, interface="127.0.0.1")
        self.addCleanup(p.stopListening)
        n = p.getHost().port

        clientF = LargeClientFactory()
        wrappedClientF = FireOnCloseFactory(clientF)
        reactor.connectTCP("127.0.0.1", n, wrappedClientF)

        # Wait for both ends to see the connection drop, so the server-side
        # connection is torn down before the test finishes.
        d = defer.gatherResults([wrappedF.deferred, wrappedClientF.deferred])

        def check(ignored):
            self.assertTrue(clientF.normal,
                            "client received data is abnormal "
                            "(%d != %d)" % (len(clientF.dataBuffer), totalLen))
            self.assertTrue(clientF.done,
                            "client didn't see connection dropped")
        return d.addCallback(check)
114