Ticket #1442: endpoints3.diff

File endpoints3.diff, 14.6 KB (added by rwall, 10 years ago)

This time with connect and listen united, a simpler wrapping protocol, and documented IAddress.buildEndpoint

  • twisted/test/test_endpoints.py

     
     1from twisted.trial import unittest
     2from twisted.internet import (defer, error, interfaces, reactor,
     3                              _sslverify as sslverify)
     4from twisted.internet.address import IPv4Address, UNIXAddress
     5from twisted.internet.protocol import ClientFactory, Protocol, ServerFactory
     6from twisted.internet.endpoints import (TCPEndpoint, UNIXEndpoint)
     7from twisted.test.test_sslverify import makeCertificate
     8
     9class ServerProtocol(Protocol):
     10    def connectionMade(self):
     11        self.factory.onConnectionMade.callback(self)
     12
     13    def connectionLost(self, *a):
     14        self.factory.onConnectionLost.callback(self)
     15
     16class ClientProtocol(Protocol):
     17    def connectionMade(self):
     18        self.factory.onConnectionMade.callback(self)
     19
     20    def connectionLost(self, *a):
     21        self.factory.onConnectionLost.callback(self)
     22
     23class MyServerFactory(ServerFactory):
     24    protocol = ServerProtocol
     25   
     26    def __init__(self):
     27        self.onConnectionMade = defer.Deferred()
     28        self.onConnectionLost = defer.Deferred()
     29
     30class MyClientFactory(ClientFactory):
     31    protocol = ClientProtocol
     32   
     33    def __init__(self):
     34        self.onConnectionMade = defer.Deferred()
     35        self.onConnectionLost = defer.Deferred()
     36
     37class PortAndConnectorCleanerUpper(unittest.TestCase):
     38    def setUp(self):
     39        self.listeningPorts = []
     40        self.clientConnections = []
     41   
     42    def tearDown(self):
     43        map(lambda p: p.stopListening(), self.listeningPorts)
     44        map(lambda c: c.disconnect(), self.clientConnections)
     45
     46class EndpointTestCaseMixin(object):
     47    def test_EndpointConnectSuccess(self):
     48        """Test that Endpoint can connect and returns a deferred who
     49        gets called back with a protocol instance.
     50        """
     51        sf = MyServerFactory()
     52        p = self.createServer(sf)
     53        addr = p.getHost()
     54        ep = self.createEndpoint(addr)
     55        cf = MyClientFactory()
     56        d = ep.connect(reactor, cf)
     57        self.assertTrue(isinstance(d, defer.Deferred))
     58        def onConnectSuccess(proto):
     59            self.assertTrue(interfaces.IProtocol.providedBy(proto))
     60            proto.transport.loseConnection()
     61        d.addCallback(onConnectSuccess)
     62        return defer.gatherResults([sf.onConnectionMade, cf.onConnectionMade, d])
     63       
     64    def test_EndpointConnectFailure(self):
     65        """Test that if an Endpoint tries to connect to a none
     66        listening port that it gets a ConnectError failure.
     67        """
     68        p = self.createServer(MyServerFactory())
     69        addr = p.getHost()
     70        p.loseConnection()
     71
     72        ep = self.createEndpoint(addr)
     73        d = ep.connect(reactor, MyClientFactory())
     74        self.failUnlessFailure(d, error.ConnectError)
     75        return d
     76   
     77    def test_EndpointListenSuccess(self):
     78        """Test that Endpoint can listen and returns a deferred that
     79        gets called back with a port instance.
     80        """
     81        ep = self.createEndpoint()
     82        sf = MyServerFactory()
     83        d = ep.listen(reactor, sf)
     84        self.assertTrue(isinstance(d, defer.Deferred))
     85        def onListenSuccess(port):
     86            self.assertTrue(interfaces.IListeningPort.providedBy(port))
     87            self.listeningPorts.append(port)
     88            return port.getHost()
     89        d.addCallback(onListenSuccess)
     90        def connectTo(addr):
     91            self.createClient(addr, MyClientFactory())
     92        d.addCallback(connectTo)
     93        return defer.gatherResults([sf.onConnectionMade, d])
     94
     95    def test_EndpointListenFailure(self):
     96        """Test that if Endpoint tries to listen on an already listening
     97        port, that a CannotListenError failure is errbacked.
     98        """
     99        p = self.createServer(MyServerFactory())
     100        addr = p.getHost()
     101        ep = self.createEndpoint(addr)
     102        d = ep.listen(reactor, MyServerFactory())
     103        self.failUnlessFailure(d, error.CannotListenError)
     104        return d
     105
     106class TCPEndpointsTestCase(PortAndConnectorCleanerUpper, EndpointTestCaseMixin):
     107    def createServer(self, factory):
     108        p = reactor.listenTCP(0, factory)
     109        self.listeningPorts.append(p)
     110        return p
     111
     112    def createClient(self, address, factory):
     113        c = reactor.connectTCP(address.host, address.port, factory)
     114        self.clientConnections.append(c)
     115        return c
     116
     117    def createEndpoint(self, address=None):
     118        if not address:
     119            address = IPv4Address("TCP", "localhost", 0)
     120        return TCPEndpoint(address.host, address.port)
     121
     122class SSLEndpointsTestCase(PortAndConnectorCleanerUpper, EndpointTestCaseMixin):
     123   
     124    def setUpClass(self):
     125        self.sKey, self.sCert = makeCertificate(
     126            O="Server Test Certificate",
     127            CN="server")
     128        self.cKey, self.cCert = makeCertificate(
     129            O="Client Test Certificate",
     130            CN="client")
     131        self.serverSSLContext = sslverify.OpenSSLCertificateOptions(privateKey=self.sKey, certificate=self.sCert, requireCertificate=False)
     132        self.clientSSLContext = sslverify.OpenSSLCertificateOptions(requireCertificate=False)
     133       
     134    def createServer(self, factory):
     135        p = reactor.listenSSL(0, factory, self.serverSSLContext)
     136        self.listeningPorts.append(p)
     137        return p
     138
     139    def createClient(self, address, factory):
     140        c = reactor.connectSSL(address.host, address.port, factory,
     141                               self.clientSSLContext)
     142        self.clientConnections.append(c)
     143        return c
     144
     145    def createEndpoint(self, address=None):
     146        if not address:
     147            address = IPv4Address("TCP", "localhost", 0)
     148        return TCPEndpoint(address.host, address.port,
     149                           sslContextFactory=self.clientSSLContext)
     150       
     151
     152class UNIXEndpointsTestCase(PortAndConnectorCleanerUpper, EndpointTestCaseMixin):
     153    def createServer(self, factory):
     154        p = reactor.listenUNIX(self.mktemp(), factory)
     155        self.listeningPorts.append(p)
     156        return p
     157
     158    def createClient(self, address, factory):
     159        c = reactor.connectUNIX(address.name, factory)
     160        self.clientConnections.append(c)
     161        return c
     162
     163    def createEndpoint(self, address=None):
     164        if not address:
     165            address = UNIXAddress(self.mktemp())
     166        return UNIXEndpoint(address.name)
  • twisted/internet/endpoints.py

     
     1# -*- test-case-name: twisted.test.test_endpoints -*-
     2
     3from zope.interface import implements, providedBy, directlyProvides
     4
     5from twisted.internet import interfaces
     6from twisted.internet import defer, protocol
     7from twisted.internet.protocol import ClientFactory, Protocol
     8
     9class _WrappingProtocol(Protocol):
     10    """I wrap another protocol in order to notify my user when a connection has
     11    been made.
     12    """
     13    def __init__(self, factory, wrappedProtocol):
     14        self.factory = factory
     15        self.wrappedProtocol = wrappedProtocol
     16       
     17    def connectionMade(self):
     18        """XXX: As soon as I am connected, I connect my wrappedProtocol, giving
     19        it my transport. Is it okay for a transport to be associated with more
     20        than one protocol? Transport calls dataReceived on me and I in turn call
     21        dataReceived on my wrappedProtocol. The wrappedProtocol may call
     22        transport.write or transport.loseConnection etc
     23        """
     24       
     25        self.wrappedProtocol.makeConnection(self.transport)
     26        self.factory.deferred.callback(self.wrappedProtocol)
     27       
     28    def dataReceived(self, data):
     29        return self.wrappedProtocol.dataReceived(data)
     30
     31    def connectionLost(self, reason):
     32        return self.wrappedProtocol.connectionLost(reason)
     33       
     34class _WrappingFactory(ClientFactory):
     35    protocol = _WrappingProtocol
     36
     37    def __init__(self, wrappedFactory):
     38        self.wrappedFactory = wrappedFactory
     39        self.deferred = defer.Deferred()
     40
     41    def buildProtocol(self, addr):
     42        try:
     43            proto = self.wrappedFactory.buildProtocol(addr)
     44        except:
     45            self.deferred.errback()
     46        else:
     47            return self.protocol(self, proto)
     48
     49    def clientConnectionFailed(self, connector, reason):
     50        self.deferred.errback(reason)
     51
     52
     53class TCPEndpoint(object):
     54    implements(interfaces.IClientEndpoint, interfaces.IServerEndpoint)
     55
     56    def __init__(self, host, port, connectArgs={}, listenArgs={},
     57                 sslContextFactory=None):
     58        """
     59        @param host: A hostname, used only when connecting
     60        @param port: The port number, used both when connecting and listening
     61        @param connectArgs: An optional dict of keyword args that will be passed
     62        to L{twisted.internet.interfaces.IReactorTCP.connectTCP}
     63        @param listenArgs: An optional dict of keyword args that will be passed
     64        to L{twisted.internet.interfaces.IReactorTCP.listenTCP}
     65        @param sslContextFactory: An optional instance of
     66        L{twisted.internet._sslverify.OpenSSLCertificateOptions}. If given, it
     67        makes L{connect} and L{listen} use the corresponding methods from
     68        L{twisted.internet.interfaces.IReactorSSL}
     69        """
     70        self.host = host
     71        self.port = port
     72        self.connectArgs = dict(timeout=30, bindAddress=None)
     73        self.connectArgs.update(connectArgs)
     74        self.listenArgs = dict(backlog=50, interface='')
     75        self.listenArgs.update(listenArgs)
     76        self.sslContextFactory = sslContextFactory
     77       
     78    def connect(self, reactor, clientFactory):
     79        wf = _WrappingFactory(clientFactory)
     80        connectArgs = self.connectArgs
     81        connectMethod = reactor.connectTCP
     82        if self.sslContextFactory:
     83            connectMethod = reactor.connectSSL
     84            connectArgs["contextFactory"] = self.sslContextFactory
     85           
     86        d = defer.execute(connectMethod, self.host, self.port, wf, **connectArgs)
     87
     88        d.addCallback(lambda _: wf.deferred)
     89
     90        return d
     91
     92    def listen(self, reactor, serverFactory):
     93        wf = _WrappingFactory(serverFactory)
     94        listenArgs = self.listenArgs
     95        listenMethod = reactor.listenTCP
     96        if self.sslContextFactory:
     97            listenMethod = reactor.listenSSL
     98            listenArgs["contextFactory"] = self.sslContextFactory
     99        return defer.execute(listenMethod, self.port, wf, **listenArgs)
     100
     101class UNIXEndpoint(object):
     102    implements(interfaces.IClientEndpoint, interfaces.IServerEndpoint)
     103
     104    def __init__(self, address, connectArgs={}, listenArgs={}):
     105        """
     106        @param address: The path to the Unix socket file, used both when
     107        connecting and listening
     108        @param connectArgs: A dict of keyword args that will be passed to
     109        L{twisted.internet.interfaces.IReactorUNIX.connectUNIX}
     110        @param listenArgs: A dict of keyword args that will be passed to
     111        L{twisted.internet.interfaces.IReactorUNIX.listenUNIX}
     112        """
     113
     114        self.address = address
     115        self.connectArgs = dict(timeout=30, checkPID=0)
     116        self.connectArgs.update(connectArgs)
     117        self.listenArgs = dict(backlog=50, mode=0666, wantPID=0)
     118        self.listenArgs.update(listenArgs)
     119
     120    def connect(self, reactor, clientFactory):
     121        wf = _WrappingFactory(clientFactory)
     122        d = defer.execute(reactor.connectUNIX, self.address, wf,
     123                          **self.connectArgs)
     124
     125        d.addCallback(lambda _: wf.deferred)
     126
     127        return d
     128
     129    def listen(self, reactor, serverFactory):
     130        wf = _WrappingFactory(serverFactory)
     131        return defer.execute(reactor.listenUNIX, self.address, wf,
     132                             **self.listenArgs)
  • twisted/internet/interfaces.py

     
    1818
    1919    Default implementations are in L{twisted.internet.address}.
    2020    """
     21   
     22    def buildEndpoint():
     23        """
     24        @return: an instance providing both L{IClientEndpoint} and L{IServerEndpoint}
     25        """
    2126
    2227
    2328### Reactor Interfaces
     
    12941299
    12951300    def leaveGroup(addr, interface=""):
    12961301        """Leave multicast group, return Deferred of success."""
     1302
     1303class IClientEndpoint(Interface):
     1304    """Object that represents a remote endpoint that we wish to connect to.
     1305    """
     1306    def connect(reactor, clientFactory):
     1307        """
     1308        @param reactor: The reactor
     1309        @param clientFactory: A provider of L{IProtocolFactory}
     1310        @return: A L{Deferred} that results in an L{IProtocol} upon successful
     1311        connection otherwise a L{ConnectError}
     1312        """
     1313
     1314class IServerEndpoint(Interface):
     1315    """Object representing an endpoint where we will listen for connections.
     1316    """
     1317
     1318    def listen(callable):
     1319        """
     1320        @param reactor: The reactor
     1321        @param serverFactory: A provider of L{IProtocolFactory}
     1322        @return: A L{Deferred} that results in an L{IListeningPort} or an
     1323        L{CannotListenError}
     1324        """
  • twisted/internet/address.py

     
    55"""Address objects for network connections."""
    66
    77import warnings, os
     8
    89from zope.interface import implements
     10
    911from twisted.internet.interfaces import IAddress
     12from twisted.internet.endpoints import TCPEndpoint, UNIXEndpoint
    1013
    11 
    1214class IPv4Address(object):
    1315    """
    1416    Object representing an IPv4 socket endpoint.
     
    3436        self.port = port
    3537        self._bwHack = _bwHack
    3638
     39    def buildEndpoint(self):
     40        if self.type == "TCP":
     41            return TCPEndpoint(self.host, self.port)
     42        else:
     43            raise NotImplementedError
     44
    3745    def __getitem__(self, index):
    3846        warnings.warn("IPv4Address.__getitem__ is deprecated.  Use attributes instead.",
    3947                      category=DeprecationWarning, stacklevel=2)
     
    7078    def __init__(self, name, _bwHack='UNIX'):
    7179        self.name = name
    7280        self._bwHack = _bwHack
    73    
     81
     82    def buildEndpoint(self):
     83        return UNIXEndpoint(self.name)
     84
    7485    def __getitem__(self, index):
    7586        warnings.warn("UNIXAddress.__getitem__ is deprecated.  Use attributes instead.",
    7687                      category=DeprecationWarning, stacklevel=2)