Ticket #1442: endpoints3.diff

File endpoints3.diff, 14.6 KB (added by rwall, 8 years ago)

This time with connect and listen united, a simpler wrapping protocol, and documented IAddress.buildEndpoint

  • twisted/test/test_endpoints.py

     
     1from twisted.trial import unittest 
     2from twisted.internet import (defer, error, interfaces, reactor,  
     3                              _sslverify as sslverify) 
     4from twisted.internet.address import IPv4Address, UNIXAddress 
     5from twisted.internet.protocol import ClientFactory, Protocol, ServerFactory 
     6from twisted.internet.endpoints import (TCPEndpoint, UNIXEndpoint) 
     7from twisted.test.test_sslverify import makeCertificate 
     8 
     9class ServerProtocol(Protocol): 
     10    def connectionMade(self): 
     11        self.factory.onConnectionMade.callback(self) 
     12 
     13    def connectionLost(self, *a): 
     14        self.factory.onConnectionLost.callback(self) 
     15 
     16class ClientProtocol(Protocol): 
     17    def connectionMade(self): 
     18        self.factory.onConnectionMade.callback(self) 
     19 
     20    def connectionLost(self, *a): 
     21        self.factory.onConnectionLost.callback(self) 
     22 
     23class MyServerFactory(ServerFactory): 
     24    protocol = ServerProtocol 
     25     
     26    def __init__(self): 
     27        self.onConnectionMade = defer.Deferred() 
     28        self.onConnectionLost = defer.Deferred() 
     29 
     30class MyClientFactory(ClientFactory): 
     31    protocol = ClientProtocol 
     32     
     33    def __init__(self): 
     34        self.onConnectionMade = defer.Deferred() 
     35        self.onConnectionLost = defer.Deferred() 
     36 
     37class PortAndConnectorCleanerUpper(unittest.TestCase): 
     38    def setUp(self): 
     39        self.listeningPorts = [] 
     40        self.clientConnections = [] 
     41     
     42    def tearDown(self): 
     43        map(lambda p: p.stopListening(), self.listeningPorts) 
     44        map(lambda c: c.disconnect(), self.clientConnections) 
     45 
     46class EndpointTestCaseMixin(object): 
     47    def test_EndpointConnectSuccess(self): 
     48        """Test that Endpoint can connect and returns a deferred who 
     49        gets called back with a protocol instance.  
     50        """ 
     51        sf = MyServerFactory() 
     52        p = self.createServer(sf) 
     53        addr = p.getHost() 
     54        ep = self.createEndpoint(addr) 
     55        cf = MyClientFactory() 
     56        d = ep.connect(reactor, cf) 
     57        self.assertTrue(isinstance(d, defer.Deferred)) 
     58        def onConnectSuccess(proto): 
     59            self.assertTrue(interfaces.IProtocol.providedBy(proto)) 
     60            proto.transport.loseConnection() 
     61        d.addCallback(onConnectSuccess) 
     62        return defer.gatherResults([sf.onConnectionMade, cf.onConnectionMade, d]) 
     63         
     64    def test_EndpointConnectFailure(self): 
     65        """Test that if an Endpoint tries to connect to a none  
     66        listening port that it gets a ConnectError failure. 
     67        """ 
     68        p = self.createServer(MyServerFactory()) 
     69        addr = p.getHost() 
     70        p.loseConnection() 
     71 
     72        ep = self.createEndpoint(addr) 
     73        d = ep.connect(reactor, MyClientFactory()) 
     74        self.failUnlessFailure(d, error.ConnectError) 
     75        return d 
     76     
     77    def test_EndpointListenSuccess(self): 
     78        """Test that Endpoint can listen and returns a deferred that 
     79        gets called back with a port instance.  
     80        """ 
     81        ep = self.createEndpoint() 
     82        sf = MyServerFactory() 
     83        d = ep.listen(reactor, sf) 
     84        self.assertTrue(isinstance(d, defer.Deferred)) 
     85        def onListenSuccess(port): 
     86            self.assertTrue(interfaces.IListeningPort.providedBy(port)) 
     87            self.listeningPorts.append(port) 
     88            return port.getHost() 
     89        d.addCallback(onListenSuccess) 
     90        def connectTo(addr): 
     91            self.createClient(addr, MyClientFactory()) 
     92        d.addCallback(connectTo) 
     93        return defer.gatherResults([sf.onConnectionMade, d]) 
     94 
     95    def test_EndpointListenFailure(self): 
     96        """Test that if Endpoint tries to listen on an already listening  
     97        port, that a CannotListenError failure is errbacked.  
     98        """ 
     99        p = self.createServer(MyServerFactory()) 
     100        addr = p.getHost() 
     101        ep = self.createEndpoint(addr) 
     102        d = ep.listen(reactor, MyServerFactory()) 
     103        self.failUnlessFailure(d, error.CannotListenError) 
     104        return d 
     105 
     106class TCPEndpointsTestCase(PortAndConnectorCleanerUpper, EndpointTestCaseMixin): 
     107    def createServer(self, factory): 
     108        p = reactor.listenTCP(0, factory) 
     109        self.listeningPorts.append(p) 
     110        return p 
     111 
     112    def createClient(self, address, factory): 
     113        c = reactor.connectTCP(address.host, address.port, factory) 
     114        self.clientConnections.append(c) 
     115        return c 
     116 
     117    def createEndpoint(self, address=None): 
     118        if not address: 
     119            address = IPv4Address("TCP", "localhost", 0) 
     120        return TCPEndpoint(address.host, address.port) 
     121 
     122class SSLEndpointsTestCase(PortAndConnectorCleanerUpper, EndpointTestCaseMixin): 
     123     
     124    def setUpClass(self): 
     125        self.sKey, self.sCert = makeCertificate( 
     126            O="Server Test Certificate", 
     127            CN="server") 
     128        self.cKey, self.cCert = makeCertificate( 
     129            O="Client Test Certificate", 
     130            CN="client") 
     131        self.serverSSLContext = sslverify.OpenSSLCertificateOptions(privateKey=self.sKey, certificate=self.sCert, requireCertificate=False) 
     132        self.clientSSLContext = sslverify.OpenSSLCertificateOptions(requireCertificate=False) 
     133         
     134    def createServer(self, factory): 
     135        p = reactor.listenSSL(0, factory, self.serverSSLContext) 
     136        self.listeningPorts.append(p) 
     137        return p 
     138 
     139    def createClient(self, address, factory): 
     140        c = reactor.connectSSL(address.host, address.port, factory,  
     141                               self.clientSSLContext) 
     142        self.clientConnections.append(c) 
     143        return c 
     144 
     145    def createEndpoint(self, address=None): 
     146        if not address: 
     147            address = IPv4Address("TCP", "localhost", 0) 
     148        return TCPEndpoint(address.host, address.port,  
     149                           sslContextFactory=self.clientSSLContext) 
     150         
     151 
     152class UNIXEndpointsTestCase(PortAndConnectorCleanerUpper, EndpointTestCaseMixin): 
     153    def createServer(self, factory): 
     154        p = reactor.listenUNIX(self.mktemp(), factory) 
     155        self.listeningPorts.append(p) 
     156        return p 
     157 
     158    def createClient(self, address, factory): 
     159        c = reactor.connectUNIX(address.name, factory) 
     160        self.clientConnections.append(c) 
     161        return c 
     162 
     163    def createEndpoint(self, address=None): 
     164        if not address: 
     165            address = UNIXAddress(self.mktemp()) 
     166        return UNIXEndpoint(address.name) 
  • twisted/internet/endpoints.py

     
     1# -*- test-case-name: twisted.test.test_endpoints -*- 
     2 
     3from zope.interface import implements, providedBy, directlyProvides 
     4 
     5from twisted.internet import interfaces 
     6from twisted.internet import defer, protocol 
     7from twisted.internet.protocol import ClientFactory, Protocol 
     8 
     9class _WrappingProtocol(Protocol): 
     10    """I wrap another protocol in order to notify my user when a connection has  
     11    been made. 
     12    """ 
     13    def __init__(self, factory, wrappedProtocol): 
     14        self.factory = factory 
     15        self.wrappedProtocol = wrappedProtocol 
     16         
     17    def connectionMade(self): 
     18        """XXX: As soon as I am connected, I connect my wrappedProtocol, giving  
     19        it my transport. Is it okay for a transport to be associated with more 
     20        than one protocol? Transport calls dataReceived on me and I in turn call 
     21        dataReceived on my wrappedProtocol. The wrappedProtocol may call  
     22        transport.write or transport.loseConnection etc 
     23        """ 
     24         
     25        self.wrappedProtocol.makeConnection(self.transport) 
     26        self.factory.deferred.callback(self.wrappedProtocol) 
     27         
     28    def dataReceived(self, data): 
     29        return self.wrappedProtocol.dataReceived(data) 
     30 
     31    def connectionLost(self, reason): 
     32        return self.wrappedProtocol.connectionLost(reason) 
     33         
     34class _WrappingFactory(ClientFactory): 
     35    protocol = _WrappingProtocol 
     36 
     37    def __init__(self, wrappedFactory): 
     38        self.wrappedFactory = wrappedFactory 
     39        self.deferred = defer.Deferred() 
     40 
     41    def buildProtocol(self, addr): 
     42        try: 
     43            proto = self.wrappedFactory.buildProtocol(addr) 
     44        except: 
     45            self.deferred.errback() 
     46        else: 
     47            return self.protocol(self, proto) 
     48 
     49    def clientConnectionFailed(self, connector, reason): 
     50        self.deferred.errback(reason) 
     51 
     52 
     53class TCPEndpoint(object): 
     54    implements(interfaces.IClientEndpoint, interfaces.IServerEndpoint) 
     55 
     56    def __init__(self, host, port, connectArgs={}, listenArgs={}, 
     57                 sslContextFactory=None): 
     58        """ 
     59        @param host: A hostname, used only when connecting 
     60        @param port: The port number, used both when connecting and listening 
     61        @param connectArgs: An optional dict of keyword args that will be passed 
     62        to L{twisted.internet.interfaces.IReactorTCP.connectTCP} 
     63        @param listenArgs: An optional dict of keyword args that will be passed  
     64        to L{twisted.internet.interfaces.IReactorTCP.listenTCP} 
     65        @param sslContextFactory: An optional instance of  
     66        L{twisted.internet._sslverify.OpenSSLCertificateOptions}. If given, it  
     67        makes L{connect} and L{listen} use the corresponding methods from  
     68        L{twisted.internet.interfaces.IReactorSSL} 
     69        """ 
     70        self.host = host 
     71        self.port = port 
     72        self.connectArgs = dict(timeout=30, bindAddress=None) 
     73        self.connectArgs.update(connectArgs) 
     74        self.listenArgs = dict(backlog=50, interface='') 
     75        self.listenArgs.update(listenArgs) 
     76        self.sslContextFactory = sslContextFactory 
     77         
     78    def connect(self, reactor, clientFactory): 
     79        wf = _WrappingFactory(clientFactory) 
     80        connectArgs = self.connectArgs 
     81        connectMethod = reactor.connectTCP 
     82        if self.sslContextFactory: 
     83            connectMethod = reactor.connectSSL 
     84            connectArgs["contextFactory"] = self.sslContextFactory 
     85             
     86        d = defer.execute(connectMethod, self.host, self.port, wf, **connectArgs) 
     87 
     88        d.addCallback(lambda _: wf.deferred) 
     89 
     90        return d 
     91 
     92    def listen(self, reactor, serverFactory): 
     93        wf = _WrappingFactory(serverFactory) 
     94        listenArgs = self.listenArgs 
     95        listenMethod = reactor.listenTCP 
     96        if self.sslContextFactory: 
     97            listenMethod = reactor.listenSSL 
     98            listenArgs["contextFactory"] = self.sslContextFactory 
     99        return defer.execute(listenMethod, self.port, wf, **listenArgs) 
     100 
     101class UNIXEndpoint(object): 
     102    implements(interfaces.IClientEndpoint, interfaces.IServerEndpoint) 
     103 
     104    def __init__(self, address, connectArgs={}, listenArgs={}): 
     105        """ 
     106        @param address: The path to the Unix socket file, used both when  
     107        connecting and listening 
     108        @param connectArgs: A dict of keyword args that will be passed to  
     109        L{twisted.internet.interfaces.IReactorUNIX.connectUNIX} 
     110        @param listenArgs: A dict of keyword args that will be passed to  
     111        L{twisted.internet.interfaces.IReactorUNIX.listenUNIX} 
     112        """ 
     113 
     114        self.address = address 
     115        self.connectArgs = dict(timeout=30, checkPID=0) 
     116        self.connectArgs.update(connectArgs) 
     117        self.listenArgs = dict(backlog=50, mode=0666, wantPID=0) 
     118        self.listenArgs.update(listenArgs) 
     119 
     120    def connect(self, reactor, clientFactory): 
     121        wf = _WrappingFactory(clientFactory) 
     122        d = defer.execute(reactor.connectUNIX, self.address, wf,  
     123                          **self.connectArgs) 
     124 
     125        d.addCallback(lambda _: wf.deferred) 
     126 
     127        return d 
     128 
     129    def listen(self, reactor, serverFactory): 
     130        wf = _WrappingFactory(serverFactory) 
     131        return defer.execute(reactor.listenUNIX, self.address, wf,  
     132                             **self.listenArgs) 
  • twisted/internet/interfaces.py

     
    1818 
    1919    Default implementations are in L{twisted.internet.address}. 
    2020    """ 
     21     
     22    def buildEndpoint(): 
     23        """ 
     24        @return: an instance providing both L{IClientEndpoint} and L{IServerEndpoint} 
     25        """ 
    2126 
    2227 
    2328### Reactor Interfaces 
     
    12941299 
    12951300    def leaveGroup(addr, interface=""): 
    12961301        """Leave multicast group, return Deferred of success.""" 
     1302 
     1303class IClientEndpoint(Interface): 
     1304    """Object that represents a remote endpoint that we wish to connect to. 
     1305    """ 
     1306    def connect(reactor, clientFactory): 
     1307        """ 
     1308        @param reactor: The reactor 
     1309        @param clientFactory: A provider of L{IProtocolFactory} 
     1310        @return: A L{Deferred} that results in an L{IProtocol} upon successful 
     1311        connection otherwise a L{ConnectError} 
     1312        """ 
     1313 
     1314class IServerEndpoint(Interface): 
     1315    """Object representing an endpoint where we will listen for connections. 
     1316    """ 
     1317 
     1318    def listen(callable): 
     1319        """ 
     1320        @param reactor: The reactor 
     1321        @param serverFactory: A provider of L{IProtocolFactory} 
     1322        @return: A L{Deferred} that results in an L{IListeningPort} or an  
     1323        L{CannotListenError} 
     1324        """ 
  • twisted/internet/address.py

     
    55"""Address objects for network connections.""" 
    66 
    77import warnings, os 
     8 
    89from zope.interface import implements 
     10 
    911from twisted.internet.interfaces import IAddress 
     12from twisted.internet.endpoints import TCPEndpoint, UNIXEndpoint 
    1013 
    11  
    1214class IPv4Address(object): 
    1315    """ 
    1416    Object representing an IPv4 socket endpoint. 
     
    3436        self.port = port 
    3537        self._bwHack = _bwHack 
    3638 
     39    def buildEndpoint(self): 
     40        if self.type == "TCP": 
     41            return TCPEndpoint(self.host, self.port) 
     42        else: 
     43            raise NotImplementedError 
     44 
    3745    def __getitem__(self, index): 
    3846        warnings.warn("IPv4Address.__getitem__ is deprecated.  Use attributes instead.", 
    3947                      category=DeprecationWarning, stacklevel=2) 
     
    7078    def __init__(self, name, _bwHack='UNIX'): 
    7179        self.name = name 
    7280        self._bwHack = _bwHack 
    73      
     81 
     82    def buildEndpoint(self): 
     83        return UNIXEndpoint(self.name) 
     84 
    7485    def __getitem__(self, index): 
    7586        warnings.warn("UNIXAddress.__getitem__ is deprecated.  Use attributes instead.", 
    7687                      category=DeprecationWarning, stacklevel=2)