Ticket #2060: iaddress-connect-listen2.diff

File iaddress-connect-listen2.diff, 9.7 KB (added by rwall, 8 years ago)

A updated diff requires reactor to be passed as an argument and without redefining backlog default value; plus an SSL option, thanks to exarkun and tv for ideas.

  • twisted/test/test_unix.py

     
    299299 
    300300        return defer.maybeDeferred(p.stopListening).addCallback(stoppedListening) 
    301301 
     302class UNIXAddressTestCase(PortCleanerUpper): 
     303    def testConnect(self): 
     304        """Test that UNIXAddress.connect returns an IConnector""" 
     305        socketPath = self.mktemp() 
     306        listeningPort = reactor.listenUNIX(socketPath, Factory(self, socketPath)) 
     307        self.ports.append(listeningPort) 
     308        addr = address.UNIXAddress(socketPath) 
     309        conn = addr.connect(reactor, TestClientFactory(self, socketPath)) 
     310        self.assertTrue(interfaces.IConnector.providedBy(conn)) 
    302311 
     312    def testListen(self): 
     313        """Test that UNIXAddress.listen returns an IListeningPort""" 
     314        socketPath = self.mktemp() 
     315        addr = address.UNIXAddress(socketPath) 
     316        listeningPort = addr.listen(reactor, Factory(self, socketPath)) 
     317        self.ports.append(listeningPort) 
     318        self.assertTrue(interfaces.IListeningPort.providedBy(listeningPort)) 
     319 
    303320if not interfaces.IReactorUNIX(reactor, None): 
    304321    UnixSocketTestCase.skip = "This reactor does not support UNIX domain sockets" 
    305322if not interfaces.IReactorUNIXDatagram(reactor, None): 
  • twisted/test/test_tcp.py

     
    1515from twisted.internet import protocol, reactor, defer, interfaces 
    1616from twisted.internet import error 
    1717from twisted.internet.address import IPv4Address 
    18 from twisted.internet.interfaces import IHalfCloseableProtocol 
     18from twisted.internet.interfaces import (IConnector, IHalfCloseableProtocol,  
     19                                         IListeningPort) 
    1920from twisted.protocols import policies 
    2021 
    2122 
     
    11421143        d.addCallback(lambda _: log.flushErrors(RuntimeError)) 
    11431144        return d 
    11441145 
     1146class IPv4AddressTestCase(PortCleanerUpper): 
     1147    def testConnect(self): 
     1148        """Test that IPv4Address.connect returns an IConnector""" 
     1149        factory = ClosingFactory() 
     1150        listeningPort = reactor.listenTCP(0, factory) 
     1151        self.ports.append(listeningPort) 
     1152        factory.port = listeningPort 
     1153        portNo = listeningPort.getHost().port 
     1154        addr = IPv4Address("TCP", "127.0.0.1", portNo) 
     1155        conn = addr.connect(reactor, MyClientFactory()) 
     1156        self.assertTrue(IConnector.providedBy(conn)) 
    11451157 
     1158    def testListen(self): 
     1159        """Test that IPv4Address.listen returns an IListeningPort""" 
     1160        factory = ClosingFactory() 
     1161        addr = IPv4Address("TCP", "0.0.0.0", 0) 
     1162        listeningPort = addr.listen(reactor, factory) 
     1163        self.ports.append(listeningPort) 
     1164        factory.port = listeningPort 
     1165        self.assertTrue(IListeningPort.providedBy(listeningPort)) 
     1166 
    11461167try: 
    11471168    import resource 
    11481169except ImportError: 
  • twisted/test/test_sslverify.py

     
    77from OpenSSL.crypto import TYPE_RSA 
    88 
    99from twisted.trial import unittest 
    10 from twisted.internet import protocol, defer, reactor 
     10from twisted.internet import address, protocol, defer, reactor, ssl 
    1111from twisted.python import log 
    1212 
    1313from twisted.internet import _sslverify as sslverify 
     
    425425            sslverify.Certificate.peerFromTransport( 
    426426                _ActualSSLTransport()).serialNumber(), 
    427427            12346) 
     428 
     429class IPv4AddressTestCase(unittest.TestCase): 
     430    serverPort = clientConn = None 
     431    onServerLost = onClientLost = None 
     432 
     433    def setUpClass(self): 
     434        self.sKey, self.sCert = makeCertificate( 
     435            O="Server Test Certificate", 
     436            CN="server") 
     437        self.cKey, self.cCert = makeCertificate( 
     438            O="Client Test Certificate", 
     439            CN="client") 
     440 
     441    def tearDown(self): 
     442        if self.serverPort is not None: 
     443            self.serverPort.stopListening() 
     444        if self.clientConn is not None: 
     445            self.clientConn.disconnect() 
     446 
     447        L = [] 
     448        if self.onServerLost is not None: 
     449            L.append(self.onServerLost) 
     450        if self.onClientLost is not None: 
     451            L.append(self.onClientLost) 
     452 
     453        return defer.DeferredList(L, consumeErrors=True) 
     454 
     455    def testIAddressConnectAndListen(self): 
     456        """Test that passing an SSL context factory to IPv4Address results in an 
     457        SSL connection. Based on OpenSSLOptions.loopback and  
     458        OpenSSLOptions.testAllowedAnonymousClientConnection 
     459        """ 
     460        self.onServerLost = defer.Deferred() 
     461        self.onClientLost = defer.Deferred() 
     462        onData = defer.Deferred() 
     463 
     464        serverCertOpts = sslverify.OpenSSLCertificateOptions(privateKey=self.sKey, certificate=self.sCert, requireCertificate=False) 
     465        clientCertOpts = sslverify.OpenSSLCertificateOptions(requireCertificate=False) 
     466 
     467        serverFactory = protocol.ServerFactory() 
     468        serverFactory.protocol = DataCallbackProtocol 
     469        serverFactory.onLost = self.onServerLost 
     470        serverFactory.onData = onData 
     471 
     472        clientFactory = protocol.ClientFactory() 
     473        clientFactory.protocol = WritingProtocol 
     474        clientFactory.onLost = self.onClientLost 
     475 
     476        serverAddr = address.IPv4Address("TCP", "127.0.0.1", 0, serverCertOpts) 
     477        self.serverPort = serverAddr.listen(reactor, serverFactory) 
     478        clientAddr = address.IPv4Address("TCP", "127.0.0.1", self.serverPort.getHost().port, clientCertOpts) 
     479        self.clientConn = clientAddr.connect(reactor, clientFactory) 
     480         
     481        self.assertTrue(isinstance(self.serverPort, ssl.Port)) 
     482        self.assertTrue(isinstance(self.clientConn, ssl.Connector)) 
     483         
     484        return onData.addCallback( 
     485            lambda result: self.assertEquals(result, WritingProtocol.byte)) 
  • twisted/internet/interfaces.py

     
    1919    Default implementations are in L{twisted.internet.address}. 
    2020    """ 
    2121 
    22  
     22    def connect(reactor, factory, timeout=30): 
     23        """Attempt to connect to my address 
     24         
     25        @param reactor: The reactor instance providing one or more of: 
     26                        L{IReactorTCP}, L{IReactorUNIX}, L{IReactorUDP} 
     27        @param factory: A protocol factory providing L{IProtocolFactory} 
     28        @return: An object which provides L{IConnector}. 
     29        """ 
     30         
     31    def listen(reactor, factory): 
     32        """Attempt to listen at my address 
     33         
     34        @param reactor: The reactor instance providing one or more of: 
     35                        L{IReactorTCP}, L{IReactorUNIX}, L{IReactorUDP} 
     36        @param factory: A protocol factory providing L{IProtocolFactory} 
     37        @return: An object which provides L{IListeningPort}. 
     38        """         
    2339### Reactor Interfaces 
    2440 
    2541class IConnector(Interface): 
  • twisted/internet/address.py

     
    88from zope.interface import implements 
    99from twisted.internet.interfaces import IAddress 
    1010 
     11class _NO_BACKLOG_GIVEN:  
     12    pass 
    1113 
    1214class IPv4Address(object): 
    1315    """ 
     
    2729 
    2830    implements(IAddress) 
    2931     
    30     def __init__(self, type, host, port, _bwHack = None): 
     32    def __init__(self, type, host, port, sslContextFactory=None, _bwHack = None): 
    3133        assert type in ('TCP', 'UDP') 
    3234        self.type = type 
    3335        self.host = host 
    3436        self.port = port 
     37        self.sslContextFactory = sslContextFactory 
    3538        self._bwHack = _bwHack 
    3639 
     40    def connect(self, reactor, factory, timeout=30): 
     41        if self.sslContextFactory: 
     42            return reactor.connectSSL(self.host, self.port, factory,  
     43                                      self.sslContextFactory, timeout) 
     44        else: 
     45            return reactor.connectTCP(self.host, self.port, factory, timeout) 
     46 
     47    def listen(self, reactor, factory, backlog=_NO_BACKLOG_GIVEN): 
     48        kwargs = {} 
     49        if backlog is not _NO_BACKLOG_GIVEN: 
     50            kwargs["backlog"] = backlog 
     51        if self.sslContextFactory: 
     52            return reactor.listenSSL(self.port, factory, self.sslContextFactory,  
     53                                     interface=self.host, **kwargs) 
     54        else: 
     55            return reactor.listenTCP(self.port,  
     56                                     factory, interface=self.host, **kwargs) 
     57 
    3758    def __getitem__(self, index): 
    3859        warnings.warn("IPv4Address.__getitem__ is deprecated.  Use attributes instead.", 
    3960                      category=DeprecationWarning, stacklevel=2) 
     
    7091    def __init__(self, name, _bwHack='UNIX'): 
    7192        self.name = name 
    7293        self._bwHack = _bwHack 
    73      
     94 
     95    def connect(self, reactor, factory, timeout=30): 
     96        return reactor.connectUNIX(self.name, factory, timeout) 
     97 
     98    def listen(self, reactor, factory, backlog=_NO_BACKLOG_GIVEN): 
     99        kwargs = {} 
     100        if backlog is not _NO_BACKLOG_GIVEN: 
     101            kwargs["backlog"] = backlog 
     102        return reactor.listenUNIX(self.name, factory, **kwargs) 
     103 
    74104    def __getitem__(self, index): 
    75105        warnings.warn("UNIXAddress.__getitem__ is deprecated.  Use attributes instead.", 
    76106                      category=DeprecationWarning, stacklevel=2) 
     
    92122        return False 
    93123 
    94124    def __str__(self): 
    95         return 'UNIXSocket(%r)' % (self.name,) 
     125        return 'UNIXAddress(%r)' % (self.name,) 
    96126 
    97127 
    98128# These are for buildFactory backwards compatability due to