Ticket #5562: test_tcp.py

File test_tcp.py, 65.8 KB (added by BrianMatthews, 22 months ago)
Line 
1# Copyright (c) Twisted Matrix Laboratories.
2# See LICENSE for details.
3
4"""
5Tests for implementations of L{IReactorTCP} and the TCP parts of
6L{IReactorSocket}.
7"""
8
9__metaclass__ = type
10
11import socket, errno
12
13from zope.interface import implements
14
15from twisted.python.runtime import platform
16from twisted.python.failure import Failure
17from twisted.python import log
18
19from twisted.trial.unittest import SkipTest, TestCase
20from twisted.internet.test.reactormixins import ReactorBuilder, EndpointCreator
21from twisted.internet.test.reactormixins import ConnectableProtocol
22from twisted.internet.test.reactormixins import runProtocolsWithReactor
23from twisted.internet.error import ConnectionLost, UserError, ConnectionRefusedError
24from twisted.internet.error import ConnectionDone, ConnectionAborted
25from twisted.internet.interfaces import (
26    ILoggingContext, IConnector, IReactorFDSet, IReactorSocket)
27from twisted.internet.address import IPv4Address, IPv6Address
28from twisted.internet.defer import (
29    Deferred, DeferredList, maybeDeferred, gatherResults)
30from twisted.internet.endpoints import (
31    TCP4ServerEndpoint, TCP4ClientEndpoint)
32from twisted.internet.protocol import ServerFactory, ClientFactory, Protocol
33from twisted.internet.interfaces import (
34    IPushProducer, IPullProducer, IHalfCloseableProtocol)
35from twisted.internet.tcp import Connection, Server, _resolveIPv6
36
37from twisted.internet.test.connectionmixins import (
38    LogObserverMixin, ConnectionTestsMixin, TCPClientTestsMixin, findFreePort)
39from twisted.internet.test.test_core import ObjectModelIntegrationMixin
40from twisted.test.test_tcp import MyClientFactory, MyServerFactory
41from twisted.test.test_tcp import ClosingFactory, ClientStartStopFactory
42
try:
    from OpenSSL import SSL
except ImportError:
    # pyOpenSSL is unavailable; tests that need TLS check this flag and skip.
    useSSL = False
else:
    from twisted.internet.ssl import ClientContextFactory
    useSSL = True
50
# Probe for IPv6 support by creating (and immediately closing) an IPv6 TCP
# socket.  ipv6Skip holds the error text if that fails and None otherwise;
# IPv6-dependent tests use it as their skip reason.
try:
    socket.socket(socket.AF_INET6, socket.SOCK_STREAM).close()
except socket.error as e:
    ipv6Skip = str(e)
else:
    ipv6Skip = None
57
58
59
# Select the platform-specific helper that lists configured link local IPv6
# addresses; getLinkLocalIPv6Address picks the first entry it returns.
if platform.isWindows():
    from twisted.internet.test import _win32ifaces
    getLinkLocalIPv6Addresses = _win32ifaces.win32GetLinkLocalIPv6Addresses
else:
    try:
        from twisted.internet.test import _posixifaces
    except ImportError:
        def getLinkLocalIPv6Addresses():
            """
            Fallback used when the POSIX interface-enumeration helper cannot
            be imported: report that no link local addresses exist, which
            makes getLinkLocalIPv6Address raise SkipTest.
            """
            return []
    else:
        getLinkLocalIPv6Addresses = (
            _posixifaces.posixGetLinkLocalIPv6Addresses)
70
71
72
def getLinkLocalIPv6Address():
    """
    Find and return a configured link local IPv6 address including a scope
    identifier using the % separation syntax.  If the system has no link local
    IPv6 addresses, raise L{SkipTest} instead.

    @raise SkipTest: if no link local address can be found.  This includes
        the case where no platform-specific interface enumeration helper could
        be imported, since L{getLinkLocalIPv6Addresses} then reports none.

    @return: a C{str} giving the address (the first one reported by
        L{getLinkLocalIPv6Addresses}).
    """
    addresses = getLinkLocalIPv6Addresses()
    if not addresses:
        raise SkipTest("Link local IPv6 address unavailable")
    return addresses[0]
88
89
90
def connect(client, address):
    """
    Connect the socket C{client} to C{address}, resolving IPv6 literals first.

    @param client: a L{socket.socket} (or anything with a compatible
        C{connect} method) to connect.

    @param address: a C{(host, port)} pair.  If C{host} contains C{':'} or
        C{'%'} it is treated as an IPv6 literal, possibly carrying a scope
        identifier, and is passed through L{socket.getaddrinfo} so that the
        full socket address tuple is used.  Otherwise C{(host, port)} is used
        as-is.
    """
    # Unpack inside the body rather than in the signature: tuple parameters
    # are Python 2 only (PEP 3113).  Positional callers are unaffected.
    host, port = address
    if '%' in host or ':' in host:
        # For AF_INET6, getaddrinfo yields (host, port, flowinfo, scopeid),
        # which carries the numeric scope id connect() needs.
        sockaddr = socket.getaddrinfo(host, port)[0][4]
    else:
        sockaddr = (host, port)
    client.connect(sockaddr)
97
98
99
class FakeSocket(object):
    """
    A fake for L{socket.socket} objects.

    @ivar data: A C{str} giving the data which will be returned from
        L{FakeSocket.recv}.

    @ivar sendBuffer: A C{list} of the objects passed to L{FakeSocket.send}.

    @ivar blocking: The value most recently passed to
        L{FakeSocket.setblocking}.  This attribute does not exist until that
        method has been called.
    """
    def __init__(self, data):
        self.data = data
        self.sendBuffer = []


    def setblocking(self, blocking):
        """
        Record C{blocking} as C{self.blocking}.  Nothing else is affected.
        """
        self.blocking = blocking


    def recv(self, size):
        """
        Return all of C{self.data}, ignoring C{size}.  The data is not
        consumed, so every call returns the same value.
        """
        return self.data


    def send(self, bytes):
        """
        I{Send} all of C{bytes} by accumulating it into C{self.sendBuffer}.

        @return: The length of C{bytes}, indicating all the data has been
            accepted.
        """
        self.sendBuffer.append(bytes)
        return len(bytes)


    def shutdown(self, how):
        """
        Shutdown is not implemented.  The method is provided since real sockets
        have it and some code expects it.  No behavior of L{FakeSocket} is
        affected by a call to it.
        """


    def close(self):
        """
        Close is not implemented.  The method is provided since real sockets
        have it and some code expects it.  No behavior of L{FakeSocket} is
        affected by a call to it.
        """


    def setsockopt(self, *args):
        """
        Setsockopt is not implemented.  The method is provided since
        real sockets have it and some code expects it.  No behavior of
        L{FakeSocket} is affected by a call to it.
        """


    def fileno(self):
        """
        Return a fake file descriptor.  If actually used, this will have no
        connection to this L{FakeSocket} and will probably cause surprising
        results.
        """
        return 1
161
162
163
class TestFakeSocket(TestCase):
    """
    Tests for L{FakeSocket}, the socket stand-in used by the whitebox tests of
    L{Connection} and L{Server} in this module.
    """

    def test_blocking(self):
        """
        L{FakeSocket.setblocking} records the value passed to it in the
        C{blocking} attribute.
        """
        skt = FakeSocket("someData")
        skt.setblocking(0)
        self.assertEqual(skt.blocking, 0)


    def test_recv(self):
        """
        L{FakeSocket.recv} returns the data the fake was constructed with.
        """
        skt = FakeSocket("someData")
        self.assertEqual(skt.recv(10), "someData")


    def test_send(self):
        """
        L{FakeSocket.send} accepts the entire string passed to it, adds it to
        its send buffer, and returns its length.
        """
        skt = FakeSocket("")
        count = skt.send("foo")
        self.assertEqual(count, 3)
        self.assertEqual(skt.sendBuffer, ["foo"])
189
190
191
class FakeProtocol(Protocol):
    """
    An L{IProtocol} that returns a value from its dataReceived method.
    """
    def dataReceived(self, data):
        """
        Return something other than C{None} to trigger a deprecation warning for
        that behavior.
        """
        return ()
202
203
204
class _FakeFDSetReactor(object):
    """
    A no-op implementation of L{IReactorFDSet}, which ignores all adds and
    removes.
    """
    implements(IReactorFDSet)

    # All four methods share one implementation that accepts a descriptor
    # and does nothing.
    addReader = addWriter = removeReader = removeWriter = (
        lambda self, desc: None)
214
215
216
class TCPServerTests(TestCase):
    """
    Whitebox tests for L{twisted.internet.tcp.Server}.
    """
    def setUp(self):
        """
        Create a L{Server} wrapping a L{FakeSocket}, with a minimal fake port
        object and a no-op L{IReactorFDSet} reactor.
        """
        self.reactor = _FakeFDSetReactor()
        class FakePort(object):
            _realPortNumber = 3
        self.skt = FakeSocket("")
        self.protocol = Protocol()
        self.server = Server(
            self.skt, self.protocol, ("", 0), FakePort(), None, self.reactor)


    def test_writeAfterDisconnect(self):
        """
        L{Server.write} discards bytes passed to it if called after it has lost
        its connection.
        """
        self.server.connectionLost(
            Failure(Exception("Simulated lost connection")))
        self.server.write("hello world")
        self.assertEqual(self.skt.sendBuffer, [])


    def test_writeAfteDisconnectAfterTLS(self):
        """
        L{Server.write} discards bytes passed to it if called after it has lost
        its connection when the connection had started TLS.
        """
        self.server.TLS = True
        self.test_writeAfterDisconnect()


    def test_writeSequenceAfterDisconnect(self):
        """
        L{Server.writeSequence} discards bytes passed to it if called after it
        has lost its connection.
        """
        self.server.connectionLost(
            Failure(Exception("Simulated lost connection")))
        self.server.writeSequence(["hello world"])
        self.assertEqual(self.skt.sendBuffer, [])


    def test_writeSequenceAfteDisconnectAfterTLS(self):
        """
        L{Server.writeSequence} discards bytes passed to it if called after it
        has lost its connection when the connection had started TLS.
        """
        self.server.TLS = True
        self.test_writeSequenceAfterDisconnect()
269
270
271
class TCPConnectionTests(TestCase):
    """
    Whitebox tests for L{twisted.internet.tcp.Connection}.
    """
    def test_doReadWarningIsRaised(self):
        """
        When an L{IProtocol} implementation returns a value from its
        C{dataReceived} method, a deprecation warning is emitted.
        """
        skt = FakeSocket("someData")
        protocol = FakeProtocol()
        conn = Connection(skt, protocol)
        conn.doRead()
        warnings = self.flushWarnings([FakeProtocol.dataReceived])
        # Check the count first, so a missing warning is reported as an
        # assertion failure rather than an IndexError.
        self.assertEqual(len(warnings), 1)
        self.assertEqual(warnings[0]['category'], DeprecationWarning)
        self.assertEqual(
            warnings[0]["message"],
            "Returning a value other than None from "
            "twisted.internet.test.test_tcp.FakeProtocol.dataReceived "
            "is deprecated since Twisted 11.0.0.")


    def test_noTLSBeforeStartTLS(self):
        """
        The C{TLS} attribute of a L{Connection} instance is C{False} before
        L{Connection.startTLS} is called.
        """
        skt = FakeSocket("")
        protocol = FakeProtocol()
        conn = Connection(skt, protocol)
        self.assertFalse(conn.TLS)


    def test_tlsAfterStartTLS(self):
        """
        The C{TLS} attribute of a L{Connection} instance is C{True} after
        L{Connection.startTLS} is called.
        """
        skt = FakeSocket("")
        protocol = FakeProtocol()
        conn = Connection(skt, protocol, reactor=_FakeFDSetReactor())
        conn._tlsClientDefault = True
        conn.startTLS(ClientContextFactory(), True)
        self.assertTrue(conn.TLS)
    if not useSSL:
        test_tlsAfterStartTLS.skip = "No SSL support available"
319
320
321
class TCPCreator(EndpointCreator):
    """
    Create IPv4 TCP endpoints for L{runProtocolsWithReactor}-based tests.
    """

    # Address both the server listens on and the client connects to.
    interface = "127.0.0.1"

    def server(self, reactor):
        """
        Create a server-side TCP endpoint listening on an arbitrary free port
        (port C{0}) of C{self.interface}.
        """
        return TCP4ServerEndpoint(reactor, 0, interface=self.interface)


    def client(self, reactor, serverAddress):
        """
        Create a client end point that will connect to the given address.

        @type serverAddress: L{IPv4Address}
        """
        return TCP4ClientEndpoint(reactor, self.interface, serverAddress.port)
343
344
345
class TCP6Creator(TCPCreator):
    """
    Create IPv6 TCP endpoints for
    C{ReactorBuilder.runProtocolsWithReactor}-based tests.

    The endpoint types in question here are still the TCP4 variety, since
    these simply pass through IPv6 address literals to the reactor, and we are
    only testing address literals, not name resolution (as name resolution has
    not yet been implemented).  See http://twistedmatrix.com/trac/ticket/4470
    for more specific information about new endpoint classes.  The naming is
    slightly misleading, but presumably if you're passing an IPv6 literal, you
    know what you're asking for.
    """
    def __init__(self):
        """
        Use a link local IPv6 address as the interface.

        @raise SkipTest: if no link local IPv6 address is available.
        """
        self.interface = getLinkLocalIPv6Address()
361
362
363
class TCPTransportTestsMixin(object):
    """
    Tests for functionality which is provided by any TCP transport (IPv4, IPv6,
    client, server).
    """
    def test_largeSendBuffer(self):
        """
        Bytes written to a transport are sent reliably and in order even when
        the local send buffer fills up because writes are done more quickly than
        the network can accept them.

        The client writes C{repBlocks} blocks of C{ops} bytes and the server
        echoes everything back.  One second after echoed data starts arriving,
        the client checks whether exactly the number of bytes it wrote has come
        back.  The payload is a single repeated byte, so this verifies that
        every byte arrives exactly once, not their ordering.
        """
        smallChunk = b'X'
        ops = 260000
        repBlocks = 100
        totalLen = len(smallChunk) * ops * repBlocks

        class LargeProtocol(ConnectableProtocol):
            """
            Server side: echo every byte received back to the client.
            """
            dataLen = 0

            def dataReceived(self, data):
                self.transport.write(data)
                self.dataLen += len(data)

        class LargeClientProtocol(ConnectableProtocol):
            """
            Client side: write C{totalLen} bytes, then count the echoed bytes.
            """
            # Number of echoed bytes received so far.  Only the count is
            # checked, so the data itself is not retained; concatenating tens
            # of megabytes onto an attribute would be quadratic.
            receivedLen = 0
            # Result of the final check.  It stays False if the connection is
            # lost before a check succeeds, so the assertion below fails with
            # a message instead of an AttributeError.
            normal = False
            # Pending delayed call to extraCheck, or None.
            checkd = None

            def connectionMade(self):
                # Give both sides a moment to settle before flooding.
                self.reactor.callLater(1, self.writeData)

            def writeData(self):
                block = smallChunk * ops
                for _ in range(repBlocks):
                    self.transport.write(block)

            def dataReceived(self, data):
                self.receivedLen += len(data)
                if not self.checkd:
                    self.checkd = self.reactor.callLater(1, self.extraCheck)

            def connectionLost(self, reason):
                if self.checkd:
                    self.checkd.cancel()
                    self.checkd = None
                ConnectableProtocol.connectionLost(self, reason)

            def extraCheck(self):
                # This call has fired; forget it before anything can lead to
                # connectionLost, which would otherwise try to cancel it.
                self.checkd = None
                if self.receivedLen == totalLen:
                    self.normal = True
                    self.transport.loseConnection()
                elif self.receivedLen > totalLen:
                    self.normal = False
                    self.transport.loseConnection()
                # Otherwise more data is still in flight; the next
                # dataReceived schedules another check.

        server = LargeProtocol()
        client = LargeClientProtocol()
        runProtocolsWithReactor(self, server, client, self.endpoints)

        self.assertTrue(
            client.normal,
            "client received data is abnormal "
            "(%d != %d)" % (client.receivedLen, totalLen))
432
433
434
class TCPClientTestsBase(ReactorBuilder, ConnectionTestsMixin,
                            TCPClientTestsMixin, TCPTransportTestsMixin):
    """
    Base class for builders defining tests related to L{IReactorTCP.connectTCP}.

    Subclasses supply an C{endpoints} object (see L{TCPCreator}) whose
    C{interface} attribute is exposed as C{self.interface}.
    """
    port = 1234

    @property
    def interface(self):
        """
        Return the interface attribute from the endpoints object.
        """
        return self.endpoints.interface
448
449
450
class TCP4ClientTestsBuilder(TCPClientTestsBase):
    """
    Builder configured with IPv4 parameters for tests related to
    L{IReactorTCP.connectTCP}.
    """
    fakeDomainName = 'some-fake.domain.example.com'
    family = socket.AF_INET
    addressClass = IPv4Address

    endpoints = TCPCreator()
460
461
462
class TCP6ClientTestsBuilder(TCPClientTestsBase):
    """
    Builder configured with IPv6 parameters for tests related to
    L{IReactorTCP.connectTCP}.
    """

    if ipv6Skip:
        skip = "Platform does not support ipv6"

    family = socket.AF_INET6
    addressClass = IPv6Address


    def setUp(self):
        """
        Create the IPv6 endpoint creator (raising L{SkipTest} if no link local
        IPv6 address exists) and use its address as C{fakeDomainName}.
        """
        # Only create this object here, so that it won't be created if tests
        # are being skipped:
        self.endpoints = TCP6Creator()
        # This is used by test_addresses to test the distinction between the
        # resolved name and the name on the socket itself.  All the same
        # invariants should hold, but giving back an IPv6 address from a
        # resolver is not something the reactor can handle, so instead, we make
        # it so that the connect call for the IPv6 address test simply uses an
        # address literal.
        self.fakeDomainName = self.endpoints.interface
486
487
488
class TCPConnectorTestsBuilder(ReactorBuilder):
    """
    Builder defining tests for the L{IConnector} objects returned by
    L{IReactorTCP.connectTCP}.

    Subclasses must provide an C{interface} attribute giving the address the
    test server listens on and the client connects to.
    """

    def test_connectorIdentity(self):
        """
        L{IReactorTCP.connectTCP} returns an object which provides
        L{IConnector}.  The destination of the connector is the address which
        was passed to C{connectTCP}.  The same connector object is passed to
        the factory's C{startedConnecting} method as to the factory's
        C{clientConnectionLost} method.
        """
        serverFactory = ClosingFactory()
        reactor = self.buildReactor()
        tcpPort = reactor.listenTCP(0, serverFactory, interface=self.interface)
        serverFactory.port = tcpPort
        portNumber = tcpPort.getHost().port

        seenConnectors = []
        seenFailures = []

        clientFactory = ClientStartStopFactory()
        clientFactory.clientConnectionLost = (
            lambda connector, reason: (seenConnectors.append(connector),
                                       seenFailures.append(reason)))
        clientFactory.startedConnecting = seenConnectors.append

        connector = reactor.connectTCP(self.interface, portNumber,
                                       clientFactory)
        self.assertTrue(IConnector.providedBy(connector))
        dest = connector.getDestination()
        self.assertEqual(dest.type, "TCP")
        self.assertEqual(dest.host, self.interface)
        self.assertEqual(dest.port, portNumber)

        clientFactory.whenStopped.addBoth(lambda _: reactor.stop())

        self.runReactor(reactor)

        # Compare the connector list first: if clientConnectionLost was never
        # called this fails readably instead of with an IndexError below.
        self.assertEqual(seenConnectors, [connector, connector])
        seenFailures[0].trap(ConnectionDone)


    def test_userFail(self):
        """
        Calling L{IConnector.stopConnecting} in C{Factory.startedConnecting}
        results in C{Factory.clientConnectionFailed} being called with
        L{UserError} as the reason.
        """
        serverFactory = MyServerFactory()
        reactor = self.buildReactor()
        tcpPort = reactor.listenTCP(0, serverFactory, interface=self.interface)
        portNumber = tcpPort.getHost().port

        fatalErrors = []

        def startedConnecting(connector):
            try:
                connector.stopConnecting()
            except Exception:
                fatalErrors.append(Failure())
                reactor.stop()

        clientFactory = ClientStartStopFactory()
        clientFactory.startedConnecting = startedConnecting

        clientFactory.whenStopped.addBoth(lambda _: reactor.stop())

        reactor.callWhenRunning(lambda: reactor.connectTCP(self.interface,
                                                           portNumber,
                                                           clientFactory))

        self.runReactor(reactor)

        if fatalErrors:
            self.fail(fatalErrors[0].getTraceback())
        clientFactory.reason.trap(UserError)
        self.assertEqual(clientFactory.failed, 1)


    def test_reconnect(self):
        """
        Calling L{IConnector.connect} in C{Factory.clientConnectionLost} causes
        a new connection attempt to be made.
        """
        serverFactory = ClosingFactory()
        reactor = self.buildReactor()
        tcpPort = reactor.listenTCP(0, serverFactory, interface=self.interface)
        serverFactory.port = tcpPort
        portNumber = tcpPort.getHost().port

        clientFactory = MyClientFactory()

        def clientConnectionLost(connector, reason):
            connector.connect()
        clientFactory.clientConnectionLost = clientConnectionLost
        reactor.connectTCP(self.interface, portNumber, clientFactory)

        protocolMadeAndClosed = []
        def reconnectFailed(ignored):
            p = clientFactory.protocol
            protocolMadeAndClosed.append((p.made, p.closed))
            reactor.stop()

        clientFactory.failDeferred.addCallback(reconnectFailed)

        self.runReactor(reactor)

        # The reconnection attempt is expected to be refused (the test checks
        # for ConnectionRefusedError) after the first connection was both
        # made and closed.
        clientFactory.reason.trap(ConnectionRefusedError)
        self.assertEqual(protocolMadeAndClosed, [(1, 1)])
597
598
599
class TCP4ConnectorTestsBuilder(TCPConnectorTestsBuilder):
    """
    Run the L{TCPConnectorTestsBuilder} tests over IPv4 loopback.
    """
    interface = '127.0.0.1'
    family = socket.AF_INET
    addressClass = IPv4Address
604
605
606
class TCP6ConnectorTestsBuilder(TCPConnectorTestsBuilder):
    """
    Run the L{TCPConnectorTestsBuilder} tests over a link local IPv6 address.
    """
    family = socket.AF_INET6
    addressClass = IPv6Address

    if ipv6Skip:
        skip = "Platform does not support ipv6"

    def setUp(self):
        """
        Use a link local IPv6 address as the interface, raising L{SkipTest}
        if none is available.
        """
        self.interface = getLinkLocalIPv6Address()
616
617
618
def createTestSocket(test, addressFamily, socketType):
    """
    Create a socket for the duration of the given test.

    @param test: the test to add cleanup to; the socket's C{close} method is
        registered with C{test.addCleanup}.

    @param addressFamily: an C{AF_*} constant

    @param socketType: a C{SOCK_*} constant.

    @return: a socket object.
    """
    skt = socket.socket(addressFamily, socketType)
    test.addCleanup(skt.close)
    return skt
634
635
636
class StreamTransportTestsMixin(LogObserverMixin):
    """
    Mixin defining tests which apply to any port/connection based transport.
    """
    def test_startedListeningLogMessage(self):
        """
        When a port starts, a message including a description of the associated
        factory is logged.
        """
        loggedMessages = self.observe()
        reactor = self.buildReactor()
        class SomeFactory(ServerFactory):
            implements(ILoggingContext)
            def logPrefix(self):
                return "Crazy Factory"
        factory = SomeFactory()
        p = self.getListeningPort(reactor, factory)
        expectedMessage = self.getExpectedStartListeningLogMessage(
            p, "Crazy Factory")
        self.assertEqual((expectedMessage,), loggedMessages[0]['message'])


    def test_connectionLostLogMsg(self):
        """
        When a connection is lost, an informative message should be logged
        (see L{getExpectedConnectionLostLogMsg}): an address identifying
        the port and the fact that it was closed.
        """
        loggedMessages = []
        def logConnectionLostMsg(eventDict):
            loggedMessages.append(log.textFromEventDict(eventDict))

        reactor = self.buildReactor()
        p = self.getListeningPort(reactor, ServerFactory())
        expectedMessage = self.getExpectedConnectionLostLogMsg(p)

        # Register the observer exactly once and let the test machinery remove
        # it.  Adding it twice but removing it once would leave a global
        # observer registered after this test, and removing it only on the
        # success path would leak it whenever the test fails.
        log.addObserver(logConnectionLostMsg)
        self.addCleanup(log.removeObserver, logConnectionLostMsg)

        def stopReactor(ignored):
            reactor.stop()

        def doStopListening():
            maybeDeferred(p.stopListening).addCallback(stopReactor)

        reactor.callWhenRunning(doStopListening)
        # runReactor (unlike a bare reactor.run) fails the test on timeout
        # rather than hanging if stopListening never completes.
        self.runReactor(reactor)

        self.assertIn(expectedMessage, loggedMessages)


    def test_allNewStyle(self):
        """
        The L{IListeningPort} object is an instance of a class with no
        classic classes in its hierarchy.
        """
        reactor = self.buildReactor()
        port = self.getListeningPort(reactor, ServerFactory())
        self.assertFullyNewStyle(port)
697
698
class ListenTCPMixin(object):
    """
    Mixin which uses L{IReactorTCP.listenTCP} to hand out listening TCP ports.
    """
    def getListeningPort(self, reactor, factory, port=0, interface=''):
        """
        Get a TCP port from a reactor.

        @param port: the port number to listen on (C{0} lets the system
            choose).
        @param interface: the address to listen on, passed to C{listenTCP}.

        @return: whatever C{reactor.listenTCP} returns.
        """
        return reactor.listenTCP(port, factory, interface=interface)
708
709
710
class SocketTCPMixin(object):
    """
    Mixin which uses L{IReactorSocket.adoptStreamPort} to hand out listening TCP
    ports.
    """
    def getListeningPort(self, reactor, factory, port=0, interface=''):
        """
        Get a TCP port from a reactor, wrapping an already-initialized file
        descriptor.

        A listening, non-blocking socket is created here, passed to
        C{reactor.adoptStreamPort}, and then this copy is closed.

        @param port: the port number to bind (C{0} lets the system choose).
        @param interface: the address to bind.  An address containing C{':'}
            is treated as IPv6, possibly with a C{%} scope identifier.

        @raise SkipTest: if C{reactor} does not provide L{IReactorSocket}.
        """
        if not IReactorSocket.providedBy(reactor):
            raise SkipTest("Reactor does not provide IReactorSocket")

        if ':' in interface:
            domain = socket.AF_INET6
            # getaddrinfo produces the full IPv6 socket address tuple,
            # including any scope id, which bind() needs.
            address = socket.getaddrinfo(interface, port)[0][4]
        else:
            domain = socket.AF_INET
            address = (interface, port)

        portSock = socket.socket(domain)
        try:
            portSock.bind(address)
            portSock.listen(3)
            portSock.setblocking(False)
            try:
                return reactor.adoptStreamPort(
                    portSock.fileno(), portSock.family, factory)
            finally:
                # The socket should still be open; fileno will raise if it is
                # not.
                portSock.fileno()
        finally:
            # Now clean it up, because the rest of the test does not need it.
            # Closing here also covers bind/listen failures, which previously
            # leaked the socket.
            portSock.close()
744
745
746
747class TCPPortTestsMixin(object):
748    """
749    Tests for L{IReactorTCP.listenTCP}
750    """
751    def getExpectedStartListeningLogMessage(self, port, factory):
752        """
753        Get the message expected to be logged when a TCP port starts listening.
754        """
755        return "%s starting on %d" % (
756            factory, port.getHost().port)
757
758
759    def getExpectedConnectionLostLogMsg(self, port):
760        """
761        Get the expected connection lost message for a TCP port.
762        """
763        return "(TCP Port %s Closed)" % (port.getHost().port,)
764
765
766    def test_portGetHostOnIPv4(self):
767        """
768        When no interface is passed to L{IReactorTCP.listenTCP}, the returned
769        listening port listens on an IPv4 address.
770        """
771        reactor = self.buildReactor()
772        port = self.getListeningPort(reactor, ServerFactory())
773        address = port.getHost()
774        self.assertIsInstance(address, IPv4Address)
775
776
777    def test_portGetHostOnIPv6(self):
778        """
779        When listening on an IPv6 address, L{IListeningPort.getHost} returns
780        an L{IPv6Address} with C{host} and C{port} attributes reflecting the
781        address the port is bound to.
782        """
783        reactor = self.buildReactor()
784        host, portNumber = findFreePort(
785            family=socket.AF_INET6, interface='::1')[:2]
786        port = self.getListeningPort(
787            reactor, ServerFactory(), portNumber, host)
788        address = port.getHost()
789        self.assertIsInstance(address, IPv6Address)
790        self.assertEqual('::1', address.host)
791        self.assertEqual(portNumber, address.port)
792    if ipv6Skip:
793        test_portGetHostOnIPv6.skip = ipv6Skip
794
795
796    def test_portGetHostOnIPv6ScopeID(self):
797        """
798        When a link-local IPv6 address including a scope identifier is passed as
799        the C{interface} argument to L{IReactorTCP.listenTCP}, the resulting
800        L{IListeningPort} reports its address as an L{IPv6Address} with a host
801        value that includes the scope identifier.
802        """
803        linkLocal = getLinkLocalIPv6Address()
804        reactor = self.buildReactor()
805        port = self.getListeningPort(reactor, ServerFactory(), 0, linkLocal)
806        address = port.getHost()
807        self.assertIsInstance(address, IPv6Address)
808        self.assertEqual(linkLocal, address.host)
809    if ipv6Skip:
810        test_portGetHostOnIPv6ScopeID.skip = ipv6Skip
811
812
813    def _buildProtocolAddressTest(self, client, interface):
814        """
815        Connect C{client} to a server listening on C{interface} started with
816        L{IReactorTCP.listenTCP} and return the address passed to the factory's
817        C{buildProtocol} method.
818
819        @param client: A C{SOCK_STREAM} L{socket.socket} created with an address
820            family such that it will be able to connect to a server listening on
821            C{interface}.
822
823        @param interface: A C{str} giving an address for a server to listen on.
824            This should almost certainly be the loopback address for some
825            address family supported by L{IReactorTCP.listenTCP}.
826
827        @return: Whatever object, probably an L{IAddress} provider, is passed to
828            a server factory's C{buildProtocol} method when C{client}
829            establishes a connection.
830        """
831        class ObserveAddress(ServerFactory):
832            def buildProtocol(self, address):
833                reactor.stop()
834                self.observedAddress = address
835                return Protocol()
836
837        factory = ObserveAddress()
838        reactor = self.buildReactor()
839        port = self.getListeningPort(reactor, factory, 0, interface)
840        client.setblocking(False)
841        try:
842            connect(client, (port.getHost().host, port.getHost().port))
843        except socket.error, (errnum, message):
844            self.assertIn(errnum, (errno.EINPROGRESS, errno.EWOULDBLOCK))
845
846        self.runReactor(reactor)
847
848        return factory.observedAddress
849
850
851    def test_buildProtocolIPv4Address(self):
852        """
853        When a connection is accepted over IPv4, an L{IPv4Address} is passed
854        to the factory's C{buildProtocol} method giving the peer's address.
855        """
856        interface = '127.0.0.1'
857        client = createTestSocket(self, socket.AF_INET, socket.SOCK_STREAM)
858        observedAddress = self._buildProtocolAddressTest(client, interface)
859        self.assertEqual(
860            IPv4Address('TCP', *client.getsockname()), observedAddress)
861
862
863    def test_buildProtocolIPv6Address(self):
864        """
865        When a connection is accepted to an IPv6 address, an L{IPv6Address} is
866        passed to the factory's C{buildProtocol} method giving the peer's
867        address.
868        """
869        interface = '::1'
870        client = createTestSocket(self, socket.AF_INET6, socket.SOCK_STREAM)
871        observedAddress = self._buildProtocolAddressTest(client, interface)
872        self.assertEqual(
873            IPv6Address('TCP', *client.getsockname()[:2]), observedAddress)
874    if ipv6Skip:
875        test_buildProtocolIPv6Address.skip = ipv6Skip
876
877
878    def test_buildProtocolIPv6AddressScopeID(self):
879        """
880        When a connection is accepted to a link-local IPv6 address, an
881        L{IPv6Address} is passed to the factory's C{buildProtocol} method
882        giving the peer's address, including a scope identifier.
883        """
884        interface = getLinkLocalIPv6Address()
885        client = createTestSocket(self, socket.AF_INET6, socket.SOCK_STREAM)
886        observedAddress = self._buildProtocolAddressTest(client, interface)
887        self.assertEqual(
888            IPv6Address('TCP', *client.getsockname()[:2]), observedAddress)
889    if ipv6Skip:
890        test_buildProtocolIPv6AddressScopeID.skip = ipv6Skip
891
892
893    def _serverGetConnectionAddressTest(self, client, interface, which):
894        """
895        Connect C{client} to a server listening on C{interface} started with
896        L{IReactorTCP.listenTCP} and return the address returned by one of the
897        server transport's address lookup methods, C{getHost} or C{getPeer}.
898
899        @param client: A C{SOCK_STREAM} L{socket.socket} created with an address
900            family such that it will be able to connect to a server listening on
901            C{interface}.
902
903        @param interface: A C{str} giving an address for a server to listen on.
904            This should almost certainly be the loopback address for some
905            address family supported by L{IReactorTCP.listenTCP}.
906
907        @param which: A C{str} equal to either C{"getHost"} or C{"getPeer"}
908            determining which address will be returned.
909
910        @return: Whatever object, probably an L{IAddress} provider, is returned
911            from the method indicated by C{which}.
912        """
913        class ObserveAddress(Protocol):
914            def makeConnection(self, transport):
915                reactor.stop()
916                self.factory.address = getattr(transport, which)()
917
918        reactor = self.buildReactor()
919        factory = ServerFactory()
920        factory.protocol = ObserveAddress
921        port = self.getListeningPort(reactor, factory, 0, interface)
922        client.setblocking(False)
923        try:
924            connect(client, (port.getHost().host, port.getHost().port))
925        except socket.error, (errnum, message):
926            self.assertIn(errnum, (errno.EINPROGRESS, errno.EWOULDBLOCK))
927        self.runReactor(reactor)
928        return factory.address
929
930
931    def test_serverGetHostOnIPv4(self):
932        """
933        When a connection is accepted over IPv4, the server
934        L{ITransport.getHost} method returns an L{IPv4Address} giving the
935        address on which the server accepted the connection.
936        """
937        interface = '127.0.0.1'
938        client = createTestSocket(self, socket.AF_INET, socket.SOCK_STREAM)
939        hostAddress = self._serverGetConnectionAddressTest(
940            client, interface, 'getHost')
941        self.assertEqual(
942            IPv4Address('TCP', *client.getpeername()), hostAddress)
943
944
945    def test_serverGetHostOnIPv6(self):
946        """
947        When a connection is accepted over IPv6, the server
948        L{ITransport.getHost} method returns an L{IPv6Address} giving the
949        address on which the server accepted the connection.
950        """
951        interface = '::1'
952        client = createTestSocket(self, socket.AF_INET6, socket.SOCK_STREAM)
953        hostAddress = self._serverGetConnectionAddressTest(
954            client, interface, 'getHost')
955        self.assertEqual(
956            IPv6Address('TCP', *client.getpeername()[:2]), hostAddress)
957    if ipv6Skip:
958        test_serverGetHostOnIPv6.skip = ipv6Skip
959
960
961    def test_serverGetHostOnIPv6ScopeID(self):
962        """
963        When a connection is accepted over IPv6, the server
964        L{ITransport.getHost} method returns an L{IPv6Address} giving the
965        address on which the server accepted the connection, including the scope
966        identifier.
967        """
968        interface = getLinkLocalIPv6Address()
969        client = createTestSocket(self, socket.AF_INET6, socket.SOCK_STREAM)
970        hostAddress = self._serverGetConnectionAddressTest(
971            client, interface, 'getHost')
972        self.assertEqual(
973            IPv6Address('TCP', *client.getpeername()[:2]), hostAddress)
974    if ipv6Skip:
975        test_serverGetHostOnIPv6ScopeID.skip = ipv6Skip
976
977
978    def test_serverGetPeerOnIPv4(self):
979        """
980        When a connection is accepted over IPv4, the server
981        L{ITransport.getPeer} method returns an L{IPv4Address} giving the
982        address of the remote end of the connection.
983        """
984        interface = '127.0.0.1'
985        client = createTestSocket(self, socket.AF_INET, socket.SOCK_STREAM)
986        peerAddress = self._serverGetConnectionAddressTest(
987            client, interface, 'getPeer')
988        self.assertEqual(
989            IPv4Address('TCP', *client.getsockname()), peerAddress)
990
991
992    def test_serverGetPeerOnIPv6(self):
993        """
994        When a connection is accepted over IPv6, the server
995        L{ITransport.getPeer} method returns an L{IPv6Address} giving the
996        address on the remote end of the connection.
997        """
998        interface = '::1'
999        client = createTestSocket(self, socket.AF_INET6, socket.SOCK_STREAM)
1000        peerAddress = self._serverGetConnectionAddressTest(
1001            client, interface, 'getPeer')
1002        self.assertEqual(
1003            IPv6Address('TCP', *client.getsockname()[:2]), peerAddress)
1004    if ipv6Skip:
1005        test_serverGetPeerOnIPv6.skip = ipv6Skip
1006
1007
1008    def test_serverGetPeerOnIPv6ScopeID(self):
1009        """
1010        When a connection is accepted over IPv6, the server
1011        L{ITransport.getPeer} method returns an L{IPv6Address} giving the
1012        address on the remote end of the connection, including the scope
1013        identifier.
1014        """
1015        interface = getLinkLocalIPv6Address()
1016        client = createTestSocket(self, socket.AF_INET6, socket.SOCK_STREAM)
1017        peerAddress = self._serverGetConnectionAddressTest(
1018            client, interface, 'getPeer')
1019        self.assertEqual(
1020            IPv6Address('TCP', *client.getsockname()[:2]), peerAddress)
1021    if ipv6Skip:
1022        test_serverGetPeerOnIPv6ScopeID.skip = ipv6Skip
1023
1024
1025
class TCPPortTestsBuilder(ReactorBuilder, ListenTCPMixin, TCPPortTestsMixin,
                          ObjectModelIntegrationMixin,
                          StreamTransportTestsMixin):
    """
    Reactor test builder combining L{TCPPortTestsMixin} and the object-model
    and stream-transport test mixins with L{ListenTCPMixin}.
    """
1030
1031
1032
class TCPFDPortTestsBuilder(ReactorBuilder, SocketTCPMixin, TCPPortTestsMixin,
                            ObjectModelIntegrationMixin,
                            StreamTransportTestsMixin):
    """
    Reactor test builder combining L{TCPPortTestsMixin} and the object-model
    and stream-transport test mixins with L{SocketTCPMixin}.
    """
1037
1038
1039
class StopStartReadingProtocol(Protocol):
    """
    Protocol that pauses and immediately resumes its transport a few times
    once connected, then fires C{factory.ready} with itself.  When
    C{4 * 4096} bytes have been received, C{factory.stop} is fired with them.
    """

    def connectionMade(self):
        self.data = ''
        self.pauseResumeProducing(3)


    def pauseResumeProducing(self, counter):
        """
        Pause and resume the transport once, then on the next reactor
        iteration either repeat with C{counter - 1} or, when C{counter} has
        reached zero, fire C{factory.ready}.
        """
        self.transport.pauseProducing()
        self.transport.resumeProducing()
        callLater = self.factory.reactor.callLater
        if not counter:
            callLater(0, self.factory.ready.callback, self)
        else:
            callLater(0, self.pauseResumeProducing, counter - 1)


    def dataReceived(self, data):
        log.msg('got data', len(data))
        self.data += data
        expectedLength = 4 * 4096
        if len(self.data) == expectedLength:
            self.factory.stop.callback(self.data)
1069
1070
1071
class TCPConnectionTestsBuilder(ReactorBuilder):
    """
    Builder defining tests relating to L{twisted.internet.tcp.Connection}.
    """

    def test_stopStartReading(self):
        """
        This test verifies transport socket read state after multiple
        pause/resumeProducing calls.
        """
        sf = ServerFactory()
        reactor = sf.reactor = self.buildReactor()

        skippedReactors = ["Glib2Reactor", "Gtk2Reactor"]
        reactorClassName = reactor.__class__.__name__
        if reactorClassName in skippedReactors and platform.isWindows():
            raise SkipTest(
                "This test is broken on gtk/glib under Windows.")

        sf.protocol = StopStartReadingProtocol
        sf.ready = Deferred()
        sf.stop = Deferred()
        p = reactor.listenTCP(0, sf)
        port = p.getHost().port
        def proceed(protos, port):
            """
            Send several IOCPReactor's buffers' worth of data.

            @param protos: The result of a L{DeferredList}: a list of
                C{(success, result)} pairs for the client connection attempt
                and for the server protocol becoming ready.
            """
            # Each element is a (success, result) pair.  The pair itself is
            # always truthy, so inspect the success flag and propagate the
            # Failure if either step went wrong.
            for success, result in protos:
                if not success:
                    return result
            protos = protos[0][1], protos[1][1]
            protos[0].transport.write('x' * (2 * 4096) + 'y' * (2 * 4096))
            return sf.stop.addCallback(cleanup, protos, port)

        def cleanup(data, protos, port):
            """
            Make sure IOCPReactor didn't start several WSARecv operations
            that clobbered each other's results.
            """
            self.assertEqual(data, 'x'*(2*4096) + 'y'*(2*4096),
                                 'did not get the right data')
            return DeferredList([
                    maybeDeferred(protos[0].transport.loseConnection),
                    maybeDeferred(protos[1].transport.loseConnection),
                    maybeDeferred(port.stopListening)])

        def stop(result):
            # Stop the reactor on success *and* failure, so that a failure is
            # reported directly instead of as a reactor timeout.
            reactor.stop()
            return result

        cc = TCP4ClientEndpoint(reactor, '127.0.0.1', port)
        cf = ClientFactory()
        cf.protocol = Protocol
        d = DeferredList([cc.connect(cf), sf.ready]).addCallback(proceed, p)
        d.addBoth(stop)
        self.runReactor(reactor)
        return d


    def test_connectionLostAfterPausedTransport(self):
        """
        Alice connects to Bob.  Alice writes some bytes and then shuts down the
        connection.  Bob receives the bytes from the connection and then pauses
        the transport object.  Shortly afterwards Bob resumes the transport
        object.  At that point, Bob is notified that the connection has been
        closed.

        This is no problem for most reactors.  The underlying event notification
        API will probably just remind them that the connection has been closed.
        It is a little tricky for win32eventreactor (MsgWaitForMultipleObjects).
        MsgWaitForMultipleObjects will only deliver the close notification once.
        The reactor needs to remember that notification until Bob resumes the
        transport.
        """
        class Pauser(ConnectableProtocol):
            def __init__(self):
                self.events = []

            def dataReceived(self, bytes):
                self.events.append("paused")
                self.transport.pauseProducing()
                self.reactor.callLater(0, self.resume)

            def resume(self):
                self.events.append("resumed")
                self.transport.resumeProducing()

            def connectionLost(self, reason):
                # This is the event you have been waiting for.
                self.events.append("lost")
                ConnectableProtocol.connectionLost(self, reason)

        class Client(ConnectableProtocol):
            def connectionMade(self):
                self.transport.write("some bytes for you")
                self.transport.loseConnection()

        pauser = Pauser()
        runProtocolsWithReactor(self, pauser, Client(), TCPCreator())
        self.assertEqual(pauser.events, ["paused", "resumed", "lost"])


    def test_doubleHalfClose(self):
        """
        If one side half-closes its connection, and then the other side of the
        connection calls C{loseWriteConnection}, and then C{loseConnection} in
        C{writeConnectionLost}, the connection is closed correctly.

        This rather obscure case used to fail (see ticket #3037).
        """
        class ListenerProtocol(ConnectableProtocol):
            implements(IHalfCloseableProtocol)

            def readConnectionLost(self):
                self.transport.loseWriteConnection()

            def writeConnectionLost(self):
                self.transport.loseConnection()

        class Client(ConnectableProtocol):
            def connectionMade(self):
                self.transport.loseConnection()

        # If test fails, reactor won't stop and we'll hit timeout:
        runProtocolsWithReactor(
            self, ListenerProtocol(), Client(), TCPCreator())
1194
1195
1196
class WriteSequenceTests(ReactorBuilder):
    """
    Test for L{twisted.internet.abstract.FileDescriptor.writeSequence}.

    @ivar client: the connected client factory to be used in tests.
    @type client: L{MyClientFactory}

    @ivar server: the listening server factory to be used in tests.
    @type server: L{MyServerFactory}
    """
    def setUp(self):
        server = MyServerFactory()
        server.protocolConnectionMade = Deferred()
        server.protocolConnectionLost = Deferred()
        self.server = server

        client = MyClientFactory()
        client.protocolConnectionMade = Deferred()
        client.protocolConnectionLost = Deferred()
        self.client = client


    def setWriteBufferSize(self, transport, value):
        """
        Set the write buffer size for the given transport, managing possible
        differences (ie, IOCP). Bug #4322 should remove the need of that hack.
        """
        if getattr(transport, "writeBufferSize", None) is not None:
            transport.writeBufferSize = value
        else:
            transport.bufferSize = value


    def test_withoutWrite(self):
        """
        C{writeSequence} sends the data even if C{write} hasn't been called.
        """
        client, server = self.client, self.server
        reactor = self.buildReactor()

        port = reactor.listenTCP(0, server)

        expected = "Some sequence splitted"
        received = []

        def dataReceived(data):
            log.msg("data received: %r" % data)
            # TCP is a byte stream: the pieces written by writeSequence may be
            # delivered in more than one chunk, so accumulate them and only
            # check (and close) once enough bytes have arrived.
            received.append(data)
            soFar = "".join(received)
            if len(soFar) >= len(expected):
                self.assertEqual(soFar, expected)
                client.protocol.transport.loseConnection()

        def clientConnected(proto):
            log.msg("client connected %s" % proto)
            proto.transport.writeSequence(["Some ", "sequence ", "splitted"])

        def serverConnected(proto):
            log.msg("server connected %s" % proto)
            proto.dataReceived = dataReceived

        d1 = client.protocolConnectionMade.addCallback(clientConnected)
        d2 = server.protocolConnectionMade.addCallback(serverConnected)
        d3 = server.protocolConnectionLost
        d4 = client.protocolConnectionLost
        d = gatherResults([d1, d2, d3, d4])
        def stop(result):
            reactor.stop()
            return result
        d.addBoth(stop)

        reactor.connectTCP("127.0.0.1", port.getHost().port, client)
        self.runReactor(reactor)


    def test_writeSequenceWithUnicodeRaisesException(self):
        """
        C{writeSequence} with an element in the sequence of type unicode raises
        C{TypeError}.
        """
        client, server = self.client, self.server
        reactor = self.buildReactor()

        port = reactor.listenTCP(0, server)

        reactor.connectTCP("127.0.0.1", port.getHost().port, client)

        def serverConnected(proto):
            log.msg("server connected %s" % proto)
            exc = self.assertRaises(
                TypeError,
                proto.transport.writeSequence, [u"Unicode is not kosher"])
            self.assertEqual(str(exc), "Data must not be unicode")

        d = server.protocolConnectionMade.addCallback(serverConnected)
        d.addErrback(log.err)
        d.addCallback(lambda ignored: reactor.stop())

        self.runReactor(reactor)


    def _producerTest(self, clientConnected):
        """
        Helper for testing producers which call C{writeSequence}.  This will set
        up a connection which a producer can use.  It returns after the
        connection is closed.

        @param clientConnected: A callback which will be invoked with a client
            protocol after a connection is setup.  This is responsible for
            setting up some sort of producer.
        """
        reactor = self.buildReactor()

        port = reactor.listenTCP(0, self.server)

        # The following could probably all be much simpler, but for #5285.

        # First let the server notice the connection
        d1 = self.server.protocolConnectionMade

        # Grab the client connection Deferred now though, so we don't lose it if
        # the client connects before the server.
        d2 = self.client.protocolConnectionMade

        def serverConnected(proto):
            # Now take action as soon as the client is connected
            d2.addCallback(clientConnected)
            return d2
        d1.addCallback(serverConnected)

        d3 = self.server.protocolConnectionLost
        d4 = self.client.protocolConnectionLost

        # After the client is connected and does its producer stuff, wait for
        # the disconnection events.
        def didProducerActions(ignored):
            return gatherResults([d3, d4])
        d1.addCallback(didProducerActions)

        def stop(result):
            reactor.stop()
            return result
        d1.addBoth(stop)

        reactor.connectTCP("127.0.0.1", port.getHost().port, self.client)

        self.runReactor(reactor)


    def test_streamingProducer(self):
        """
        C{writeSequence} pauses its streaming producer if too much data is
        buffered, and then resumes it.
        """
        client, server = self.client, self.server

        class SaveActionProducer(object):
            implements(IPushProducer)
            def __init__(self):
                self.actions = []

            def pauseProducing(self):
                self.actions.append("pause")

            def resumeProducing(self):
                self.actions.append("resume")
                # Unregister the producer so the connection can close
                client.protocol.transport.unregisterProducer()
                # This is why the code below waits for the server connection
                # first - so we have it to close here.  We close the server side
                # because win32eventreactor cannot reliably observe us closing
                # the client side (#5285).
                server.protocol.transport.loseConnection()

            def stopProducing(self):
                self.actions.append("stop")

        producer = SaveActionProducer()

        def clientConnected(proto):
            # Register a streaming producer and verify that it gets paused after
            # it writes more than the local send buffer can hold.
            proto.transport.registerProducer(producer, True)
            self.assertEqual(producer.actions, [])
            self.setWriteBufferSize(proto.transport, 500)
            proto.transport.writeSequence(["x" * 50] * 20)
            self.assertEqual(producer.actions, ["pause"])

        self._producerTest(clientConnected)
        # After the send buffer gets a chance to empty out a bit, the producer
        # should be resumed.
        self.assertEqual(producer.actions, ["pause", "resume"])


    def test_nonStreamingProducer(self):
        """
        C{writeSequence} pauses its producer if too much data is buffered only
        if this is a streaming producer.
        """
        client, server = self.client, self.server
        test = self

        class SaveActionProducer(object):
            implements(IPullProducer)
            def __init__(self):
                self.actions = []

            def resumeProducing(self):
                self.actions.append("resume")
                if self.actions.count("resume") == 2:
                    client.protocol.transport.stopConsuming()
                else:
                    test.setWriteBufferSize(client.protocol.transport, 500)
                    client.protocol.transport.writeSequence(["x" * 50] * 20)

            def stopProducing(self):
                self.actions.append("stop")

        producer = SaveActionProducer()

        def clientConnected(proto):
            # Register a non-streaming producer and verify that it is resumed
            # immediately.
            proto.transport.registerProducer(producer, False)
            self.assertEqual(producer.actions, ["resume"])

        self._producerTest(clientConnected)
        # After the local send buffer empties out, the producer should be
        # resumed again.
        self.assertEqual(producer.actions, ["resume", "resume"])
1421
1422
# Install the concrete, per-reactor TestCase classes generated by each builder
# into this module's namespace so trial can discover them.  Order is kept as
# before in case later builders' classes share names with earlier ones.
for _builder in [TCP4ClientTestsBuilder, TCP6ClientTestsBuilder,
                 TCPPortTestsBuilder, TCPFDPortTestsBuilder,
                 TCPConnectionTestsBuilder, TCP4ConnectorTestsBuilder,
                 TCP6ConnectorTestsBuilder, WriteSequenceTests]:
    globals().update(_builder.makeTestCaseClasses())
del _builder
1431
1432
1433
class ServerAbortsTwice(ConnectableProtocol):
    """
    Server protocol that calls abortConnection() two times in a row as soon
    as any data arrives.
    """

    def dataReceived(self, data):
        for _ in range(2):
            self.transport.abortConnection()
1442
1443
1444
class ServerAbortsThenLoses(ConnectableProtocol):
    """
    Server protocol that, on receiving data, first calls abortConnection()
    and then loseConnection() on its transport.
    """

    def dataReceived(self, data):
        # Abort first; the loseConnection call that follows exercises the
        # "close after abort" path.
        self.transport.abortConnection()
        self.transport.loseConnection()
1453
1454
1455
class AbortServerWritingProtocol(ConnectableProtocol):
    """
    Server protocol that writes C{"ready"} as soon as the connection is made.
    """

    def connectionMade(self):
        """
        Let the client know setup is complete, so it may now abort.
        """
        greeting = "ready"
        self.transport.write(greeting)
1466
1467
1468
class ReadAbortServerProtocol(AbortServerWritingProtocol):
    """
    Server that must never see incoming data other than 'X' bytes.  The other
    side writes 'X's before calling abortConnection, so those may legitimately
    arrive; any other byte is an error.
    """

    def dataReceived(self, data):
        # strip('X') leaves something only if data holds a non-'X' byte.
        if data.strip('X'):
            raise Exception("Unexpectedly received data.")
1479
1480
1481
class NoReadServer(ConnectableProtocol):
    """
    Server protocol that stops reading the moment the connection is made.

    From the peer's point of view this looks like a lost connection, which
    should make it time out and call abortConnection().
    """

    def connectionMade(self):
        """
        Stop reading from the transport immediately.
        """
        self.transport.stopReading()
1492
1493
1494
class EventualNoReadServer(ConnectableProtocol):
    """
    Like NoReadServer, except we wait until some bytes have been delivered
    before stopping reading. This means the TLS handshake has finished, where
    applicable.
    """

    gotData = False
    stoppedReading = False


    def dataReceived(self, data):
        # Only the first delivery matters; later ones are ignored.
        if self.gotData:
            return
        self.gotData = True
        self.transport.registerProducer(self, False)
        self.transport.write("hello")


    def resumeProducing(self):
        # Stop reading exactly once, no matter how often we are resumed.
        if not self.stoppedReading:
            self.stoppedReading = True
            # We've written out the data:
            self.transport.stopReading()


    def pauseProducing(self):
        pass


    def stopProducing(self):
        pass
1527
1528
1529
class BaseAbortingClient(ConnectableProtocol):
    """
    Base class for abort-testing clients.

    Subclasses set C{inReactorMethod} to C{True} around calls that may abort
    the connection; C{connectionLost} arriving while it is set is reported
    as a re-entrancy bug.
    """
    inReactorMethod = False

    def connectionLost(self, reason):
        if not self.inReactorMethod:
            ConnectableProtocol.connectionLost(self, reason)
            return
        raise RuntimeError("BUG: connectionLost was called re-entrantly!")
1540
1541
1542
class WritingButNotAbortingClient(BaseAbortingClient):
    """
    Client that writes some data on connection and never aborts.
    """

    def connectionMade(self):
        payload = "hello"
        self.transport.write(payload)
1550
1551
1552
class AbortingClient(BaseAbortingClient):
    """
    Client that writes some data and then calls abortConnection() as soon
    as it receives anything.
    """

    def dataReceived(self, data):
        """
        Receiving data means the connection is fully set up, so write and
        abort, flagging that we are inside a call that may abort.
        """
        self.inReactorMethod = True
        self.writeAndAbort()
        self.inReactorMethod = False


    def writeAndAbort(self):
        # The 'X' bytes precede abortConnection and might still be delivered;
        # the 'Y' bytes follow it and must never arrive.
        before, after = "X" * 10000, "Y" * 10000
        self.transport.write(before)
        self.transport.abortConnection()
        self.transport.write(after)
1574
1575
1576
class AbortingTwiceClient(AbortingClient):
    """
    Client that writes, aborts, writes again, and then aborts a second time.
    """

    def writeAndAbort(self):
        AbortingClient.writeAndAbort(self)
        # Second abort on an already-aborted transport.
        self.transport.abortConnection()
1585
1586
1587
class AbortingThenLosingClient(AbortingClient):
    """
    Client that writes and aborts as L{AbortingClient} does, then also calls
    loseConnection().
    """

    def writeAndAbort(self):
        AbortingClient.writeAndAbort(self)
        # loseConnection after the abort has already been requested.
        self.transport.loseConnection()
1596
1597
1598
class ProducerAbortingClient(ConnectableProtocol):
    """
    Call abortConnection from doWrite, via resumeProducing.

    A large write is followed by registering this object as a non-streaming
    producer; later C{resumeProducing} calls abort the connection.
    """

    # NOTE(review): unlike BaseAbortingClient this starts out True, so
    # connectionLost raises unless resumeProducing has completed at least
    # once (it resets the flag to False).  Presumably deliberate - confirm.
    inReactorMethod = True
    producerStopped = False

    def write(self):
        self.transport.write("lalala" * 127000)
        # Mark the window during which registerProducer is running so that
        # resumeProducing can tell those calls apart from later ones.
        self.inRegisterProducer = True
        self.transport.registerProducer(self, False)
        self.inRegisterProducer = False


    def connectionMade(self):
        self.write()


    def resumeProducing(self):
        self.inReactorMethod = True
        # Abort only when not called from within registerProducer.
        if not self.inRegisterProducer:
            self.transport.abortConnection()
        self.inReactorMethod = False


    def stopProducing(self):
        self.producerStopped = True


    def connectionLost(self, reason):
        if not self.producerStopped:
            raise RuntimeError("BUG: stopProducing() was never called.")
        if self.inReactorMethod:
            raise RuntimeError("BUG: connectionLost called re-entrantly!")
        ConnectableProtocol.connectionLost(self, reason)
1635
1636
1637
class StreamingProducerClient(ConnectableProtocol):
    """
    Call abortConnection() when the other side has stopped reading.

    The intent is to abort only once the local socket is no longer
    writeable, which emulates the most common use of abortConnection():
    closing a connection after a timeout while the write buffers are full.

    Since it is very hard to know when that actually happens, we simply
    write a lot of data and assume no further writes will go through.
    """
    paused = False
    extraWrites = 0
    inReactorMethod = False

    def connectionMade(self):
        self.write()


    def write(self):
        """
        Register as a streaming producer, then write a large amount of data
        and wait for the buffers to fill up.
        """
        self.transport.registerProducer(self, True)
        chunk = "1234567890" * 32000
        for _ in range(100):
            self.transport.write(chunk)


    def resumeProducing(self):
        self.paused = False


    def stopProducing(self):
        pass


    def pauseProducing(self):
        """
        Called when the local buffer fills up.

        Being paused is not enough on its own: Twisted's buffers can be full
        before the OS has accepted any writes.  To get as close as possible
        to every buffer (OS buffers included) being full, schedule the abort
        a little later instead of aborting right away.
        """
        if not self.paused:
            self.paused = True
            # The delay is arbitrary; it only has to let some writes happen
            # and the outgoing OS buffers fill up -- see
            # http://twistedmatrix.com/trac/ticket/5303 for details:
            self.reactor.callLater(0.01, self.doAbort)


    def doAbort(self):
        if not self.paused:
            log.err(RuntimeError("BUG: We should be paused a this point."))
        self.inReactorMethod = True
        self.transport.abortConnection()
        self.inReactorMethod = False


    def connectionLost(self, reason):
        # Let the server read again so it notices the connection is gone:
        self.otherProtocol.transport.startReading()
        ConnectableProtocol.connectionLost(self, reason)
1710
1711
1712
class StreamingProducerClientLater(StreamingProducerClient):
    """
    Like L{StreamingProducerClient}, but only starts the large write (and so
    the eventual abortConnection()) from dataReceived, after bytes have been
    exchanged.
    """

    def connectionMade(self):
        self.transport.write("hello")
        self.gotData = False


    def dataReceived(self, data):
        # React to the first delivery only.
        if self.gotData:
            return
        self.gotData = True
        self.write()
1724    def dataReceived(self, data):
1725        if not self.gotData:
1726            self.gotData = True
1727            self.write()
1728
1729
class ProducerAbortingClientLater(ProducerAbortingClient):
    """
    Call abortConnection from doWrite, via resumeProducing, but only start
    writing once some bytes have been received, so that an SSL handshake is
    not interrupted.
    """

    def connectionMade(self):
        """
        Deliberately do nothing: unlike the base class, writing is deferred
        until data arrives.
        """


    def dataReceived(self, data):
        self.write()
1745
1746
1747
class DataReceivedRaisingClient(AbortingClient):
    """
    Client whose dataReceived calls abortConnection() and then raises.
    """

    def dataReceived(self, data):
        """
        Abort the connection, then fail with an exception.
        """
        self.transport.abortConnection()
        raise ZeroDivisionError("ONO")
1756
1757
1758
class ResumeThrowsClient(ProducerAbortingClient):
    """
    Client whose resumeProducing() calls abortConnection() and then raises.
    """

    def resumeProducing(self):
        # Calls made while registerProducer is running are ignored.
        if self.inRegisterProducer:
            return
        self.transport.abortConnection()
        raise ZeroDivisionError("ono!")


    def connectionLost(self, reason):
        # The base class insists on stopProducing having been called, but
        # once resumeProducing has blown up a consumer is justified in
        # abandoning the producer without calling stopProducing.
        ConnectableProtocol.connectionLost(self, reason)
1775
1776
1777
class AbortConnectionMixin(object):
    """
    Unit tests for L{ITransport.abortConnection}.

    This is a mixin.  The concrete test classes that include it (for example
    through L{ReactorBuilder}) must supply the trial assertion helpers
    (C{assertEqual}, C{assertIsInstance}, C{flushLoggedErrors}) and set
    C{endpoints}.
    """
    # Override in subclasses; should be an EndpointCreator instance:
    endpoints = None

    def runAbortTest(self, clientClass, serverClass,
                     clientConnectionLostReason=None):
        """
        A test runner utility function, which hooks up a matched pair of client
        and server protocols.

        We then run the reactor until both sides have disconnected, and then
        verify that the right exception resulted.

        @param clientClass: Zero-argument callable creating the client
            protocol.
        @param serverClass: Zero-argument callable creating the server
            protocol.
        @param clientConnectionLostReason: An additional exception type that
            is acceptable as the client's disconnect reason, or C{None}.
        """
        clientExpectedExceptions = (ConnectionAborted, ConnectionLost)
        serverExpectedExceptions = (ConnectionLost, ConnectionDone)
        # In TLS tests we may get SSL.Error instead of ConnectionLost,
        # since we're trashing the TLS protocol layer.
        if useSSL:
            clientExpectedExceptions += (SSL.Error,)
            serverExpectedExceptions += (SSL.Error,)
        if clientConnectionLostReason is not None:
            # The caller anticipates a specific failure, such as an exception
            # raised by the protocol itself; accept it too.
            clientExpectedExceptions = (
                (clientConnectionLostReason,) + clientExpectedExceptions)

        client = clientClass()
        server = serverClass()
        client.otherProtocol = server
        server.otherProtocol = client
        reactor = runProtocolsWithReactor(self, server, client, self.endpoints)

        # Make sure everything was shutdown correctly:
        self.assertEqual(reactor.removeAll(), [])
        # The reactor always has a timeout added in runReactor():
        delayedCalls = reactor.getDelayedCalls()
        # Build the failure message eagerly.  A lazy map() object would print
        # as "<map object ...>" on Python 3 instead of listing the calls.
        self.assertEqual(len(delayedCalls), 1,
                         [str(call) for call in delayedCalls])

        self.assertIsInstance(client.disconnectReason.value,
                              clientExpectedExceptions)
        self.assertIsInstance(server.disconnectReason.value,
                              serverExpectedExceptions)


    def test_dataReceivedAbort(self):
        """
        abortConnection() is called in dataReceived. The protocol should be
        disconnected, but connectionLost should not be called re-entrantly.
        """
        self.runAbortTest(AbortingClient, ReadAbortServerProtocol)


    def test_clientAbortsConnectionTwice(self):
        """
        abortConnection() is called twice by client.

        No exception should be thrown, and the connection will be closed.
        """
        self.runAbortTest(AbortingTwiceClient, ReadAbortServerProtocol)


    def test_clientAbortsConnectionThenLosesConnection(self):
        """
        Client calls abortConnection(), followed by loseConnection().

        No exception should be thrown, and the connection will be closed.
        """
        self.runAbortTest(AbortingThenLosingClient,
                          ReadAbortServerProtocol)


    def test_serverAbortsConnectionTwice(self):
        """
        abortConnection() is called twice by server.

        No exception should be thrown, and the connection will be closed.
        """
        self.runAbortTest(WritingButNotAbortingClient, ServerAbortsTwice,
                          clientConnectionLostReason=ConnectionLost)


    def test_serverAbortsConnectionThenLosesConnection(self):
        """
        Server calls abortConnection(), followed by loseConnection().

        No exception should be thrown, and the connection will be closed.
        """
        self.runAbortTest(WritingButNotAbortingClient,
                          ServerAbortsThenLoses,
                          clientConnectionLostReason=ConnectionLost)


    def test_resumeProducingAbort(self):
        """
        abortConnection() is called in resumeProducing, before any bytes have
        been exchanged. The protocol should be disconnected, but
        connectionLost should not be called re-entrantly.
        """
        self.runAbortTest(ProducerAbortingClient,
                          ConnectableProtocol)


    def test_resumeProducingAbortLater(self):
        """
        abortConnection() is called in resumeProducing, after some
        bytes have been exchanged. The protocol should be disconnected.
        """
        self.runAbortTest(ProducerAbortingClientLater,
                          AbortServerWritingProtocol)


    def test_fullWriteBuffer(self):
        """
        abortConnection() triggered by the write buffer being full.

        In particular, the server side stops reading. This is supposed
        to simulate a realistic timeout scenario where the client
        notices the server is no longer accepting data.

        The protocol should be disconnected, but connectionLost should not be
        called re-entrantly.
        """
        self.runAbortTest(StreamingProducerClient,
                          NoReadServer)


    def test_fullWriteBufferAfterByteExchange(self):
        """
        abortConnection() is triggered by a write buffer being full.

        However, this buffer is filled after some bytes have been exchanged,
        allowing a TLS handshake if we're testing TLS. The connection will
        then be lost.
        """
        self.runAbortTest(StreamingProducerClientLater,
                          EventualNoReadServer)


    def test_dataReceivedThrows(self):
        """
        dataReceived calls abortConnection(), and then raises an exception.

        The connection will be lost, with the thrown exception
        (C{ZeroDivisionError}) as the reason on the client. The idea here is
        that bugs should not be masked by abortConnection, in particular
        unexpected exceptions.
        """
        self.runAbortTest(DataReceivedRaisingClient,
                          AbortServerWritingProtocol,
                          clientConnectionLostReason=ZeroDivisionError)
        errors = self.flushLoggedErrors(ZeroDivisionError)
        self.assertEqual(len(errors), 1)


    def test_resumeProducingThrows(self):
        """
        resumeProducing calls abortConnection(), and then raises an exception.

        The connection will be lost, with the thrown exception
        (C{ZeroDivisionError}) as the reason on the client. The idea here is
        that bugs should not be masked by abortConnection, in particular
        unexpected exceptions.
        """
        self.runAbortTest(ResumeThrowsClient,
                          ConnectableProtocol,
                          clientConnectionLostReason=ZeroDivisionError)
        errors = self.flushLoggedErrors(ZeroDivisionError)
        self.assertEqual(len(errors), 1)
1948
1949
1950
class AbortConnectionTestCase(ReactorBuilder, AbortConnectionMixin):
    """
    Runs the L{AbortConnectionMixin} tests over TCP connections.
    """

    # Endpoint factory that runAbortTest hands to runProtocolsWithReactor.
    endpoints = TCPCreator()
1957
# Publish the concrete test case classes generated from
# AbortConnectionTestCase at module level so the test runner discovers them
# (presumably one class per available reactor -- see ReactorBuilder).
globals().update(AbortConnectionTestCase.makeTestCaseClasses())
1959
1960
1961
class SimpleUtilityTestCase(TestCase):
    """
    Direct unit tests for small helper functions in L{twisted.internet.tcp}.
    """

    skip = ipv6Skip

    def _assertNonNumericRejected(self, host, port):
        """
        Assert that L{_resolveIPv6} rejects C{(host, port)} with a
        L{socket.gaierror} whose code is L{socket.EAI_NONAME}.
        """
        exc = self.assertRaises(socket.gaierror, _resolveIPv6, host, port)
        self.assertEqual(exc.args[0], socket.EAI_NONAME)


    def test_resolveNumericHost(self):
        """
        Giving L{_resolveIPv6} a hostname rather than a numeric address raises
        L{socket.gaierror} (L{socket.EAI_NONAME}).  This shows that it passes
        L{socket.AI_NUMERICHOST} to L{socket.getaddrinfo}, so bad input cannot
        trigger a blocking name lookup.
        """
        self._assertNonNumericRejected("localhost", 1)


    def test_resolveNumericService(self):
        """
        Giving L{_resolveIPv6} a service name rather than a numeric port
        raises L{socket.gaierror} (L{socket.EAI_NONAME}).  This shows that it
        passes L{socket.AI_NUMERICSERV} to L{socket.getaddrinfo}, so bad input
        cannot trigger a blocking service lookup.
        """
        self._assertNonNumericRejected("::1", "http")

    if platform.isWindows():
        # Microsoft's getaddrinfo providers do not honour AI_NUMERICSERV; see
        # http://msdn.microsoft.com/en-us/library/windows/desktop/ms738520.aspx
        test_resolveNumericService.skip = ("The AI_NUMERICSERV flag is not "
                                           "supported by Microsoft providers.")


    def test_resolveIPv6(self):
        """
        L{_resolveIPv6} discovers the flow info and scope ID of an IPv6
        address.
        """
        result = _resolveIPv6("::1", 2)
        self.assertEqual(len(result), 4)
        flowInfo, scopeID = result[2], result[3]
        # getaddrinfo exists precisely because nothing about the local
        # network interfaces can be known in advance, so the most we can
        # check for these two values is that they are integers.
        self.assertIsInstance(flowInfo, int)
        self.assertIsInstance(scopeID, int)
        # Presentation-format addresses and port numbers are well specified,
        # though, so those must come back exactly as given.
        self.assertEqual(result[:2], ("::1", 2))
2012
2013
2014