Ticket #1442: endpoints2.diff

File endpoints2.diff, 13.6 KB (added by rwall, 10 years ago)

Dreid's branch, with regular factories, SSL support, and some tests.

  • twisted/test/test_endpoints.py

     
     1from twisted.trial import unittest
     2from twisted.internet import (defer, error, interfaces, reactor,
     3                              _sslverify as sslverify)
     4from twisted.internet.protocol import ClientFactory, Protocol, ServerFactory
     5from twisted.internet.endpoints import (TCPClientEndpoint, TCPServerEndpoint,
     6                                        UNIXClientEndpoint, UNIXServerEndpoint)
     7from twisted.test.test_sslverify import makeCertificate
     8
     9class ServerProtocol(Protocol):
     10    def connectionLost(self, *a):
     11        self.factory.onConnectionLost.callback(self)
     12
     13class ClientProtocol(Protocol):
     14    def connectionMade(self):
     15        self.factory.onConnectionMade.callback(self)
     16
     17    def connectionLost(self, *a):
     18        self.factory.onConnectionLost.callback(self)
     19
     20class MyServerFactory(ServerFactory):
     21    protocol = ServerProtocol
     22   
     23    def __init__(self):
     24        self.onConnectionLost = defer.Deferred()
     25
     26class MyClientFactory(ClientFactory):
     27    protocol = ClientProtocol
     28   
     29    def __init__(self):
     30        self.onConnectionMade = defer.Deferred()
     31        self.onConnectionLost = defer.Deferred()
     32
     33class PortAndConnectorCleanerUpper(unittest.TestCase):
     34    def setUp(self):
     35        self.listeningPorts = []
     36        self.clientConnections = []
     37   
     38    def tearDown(self):
     39        map(lambda p: p.stopListening(), self.listeningPorts)
     40        map(lambda c: c.disconnect(), self.clientConnections)
     41
     42class EndpointTestCaseMixin(object):
     43    def test_clientEndpointConnectSuccess(self):
     44        """Test that Endpoint can connect and returns a deferred who
     45        gets called back with a protocol instance.
     46        """
     47        p = self.createServer(MyServerFactory())
     48        addr = p.getHost()
     49        ep = self.createClientEndpoint(addr)
     50        d = ep.connect(reactor, MyClientFactory())
     51        self.assertTrue(isinstance(d, defer.Deferred))
     52        def onConnect(proto):
     53            self.assertTrue(interfaces.IProtocol.providedBy(proto))
     54            proto.transport.loseConnection()
     55        d.addCallback(onConnect)
     56        return d
     57   
     58    def test_clientEndpointConnectFailure(self):
     59        """Test that when a ClientEndpoint tries to connect to a none
     60        listening port that it gets a ConnectError failure.
     61        """
     62        p = self.createServer(MyServerFactory())
     63        addr = p.getHost()
     64        p.loseConnection()
     65
     66        ep = self.createClientEndpoint(addr)
     67        d = ep.connect(reactor, MyClientFactory())
     68        self.failUnlessFailure(d, error.ConnectError)
     69        return d
     70   
     71    def test_serverEndpointListenSuccess(self):
     72        """Test that ServerEndpoint can listen and returns a deferred that
     73        gets called back with a port instance.
     74        """
     75        ep = self.createServerEndpoint()
     76        d = ep.listen(reactor, MyServerFactory())
     77        self.assertTrue(isinstance(d, defer.Deferred))
     78        def onListen(port):
     79            self.assertTrue(interfaces.IListeningPort.providedBy(port))
     80            self.listeningPorts.append(port)
     81            return port.getHost()
     82        d.addCallback(onListen)
     83        def connectTo(addr):
     84            self.createClient(addr, MyClientFactory())
     85        d.addCallback(connectTo)
     86        return d
     87
     88    def test_serverEndpointListenFailure(self):
     89        """Test that if ServerEndpoint tries to listen on an already listening
     90        port, that a CannotListenError failure is errbacked.
     91        """
     92        p = self.createServer(MyServerFactory())
     93        addr = p.getHost()
     94        ep = self.createServerEndpoint(addr)
     95        d = ep.listen(reactor, MyServerFactory())
     96        self.failUnlessFailure(d, error.CannotListenError)
     97        return d
     98
     99class TCPEndpointsTestCase(PortAndConnectorCleanerUpper, EndpointTestCaseMixin):
     100    def createServer(self, factory):
     101        p = reactor.listenTCP(0, factory)
     102        self.listeningPorts.append(p)
     103        return p
     104
     105    def createClient(self, address, factory):
     106        c = reactor.connectTCP(address.host, address.port, factory)
     107        self.clientConnections.append(c)
     108        return c
     109
     110    def createClientEndpoint(self, address):
     111        return TCPClientEndpoint(address.host, address.port)
     112
     113    def createServerEndpoint(self, addr=None):
     114        arg = 0
     115        if addr:
     116            arg = addr.port
     117        return TCPServerEndpoint(arg)
     118
     119class SSLEndpointsTestCase(PortAndConnectorCleanerUpper, EndpointTestCaseMixin):
     120   
     121    def setUpClass(self):
     122        self.sKey, self.sCert = makeCertificate(
     123            O="Server Test Certificate",
     124            CN="server")
     125        self.cKey, self.cCert = makeCertificate(
     126            O="Client Test Certificate",
     127            CN="client")
     128        self.serverSSLContext = sslverify.OpenSSLCertificateOptions(privateKey=self.sKey, certificate=self.sCert, requireCertificate=False)
     129        self.clientSSLContext = sslverify.OpenSSLCertificateOptions(requireCertificate=False)
     130       
     131    def createServer(self, factory):
     132        p = reactor.listenSSL(0, factory, self.serverSSLContext)
     133        self.listeningPorts.append(p)
     134        return p
     135
     136    def createClient(self, address, factory):
     137        c = reactor.connectSSL(address.host, address.port, factory,
     138                               self.clientSSLContext)
     139        self.clientConnections.append(c)
     140        return c
     141
     142    def createClientEndpoint(self, address):
     143        return TCPClientEndpoint(address.host, address.port,
     144                                 sslContextFactory=self.clientSSLContext)
     145
     146    def createServerEndpoint(self, addr=None):
     147        arg = 0
     148        if addr:
     149            arg = addr.port
     150        return TCPServerEndpoint(arg, sslContextFactory=self.serverSSLContext)
     151       
     152
     153class UNIXEndpointsTestCase(PortAndConnectorCleanerUpper, EndpointTestCaseMixin):
     154    def createServer(self, factory):
     155        p = reactor.listenUNIX(self.mktemp(), factory)
     156        self.listeningPorts.append(p)
     157        return p
     158
     159    def createClient(self, address, factory):
     160        c = reactor.connectUNIX(address.name, factory)
     161        self.clientConnections.append(c)
     162        return c
     163
     164    def createClientEndpoint(self, address):
     165        return UNIXClientEndpoint(address.name)
     166
     167    def createServerEndpoint(self, addr=None):
     168        arg = self.mktemp()
     169        if addr:
     170            arg = addr.name
     171        return UNIXServerEndpoint(arg)
  • twisted/internet/endpoints.py

     
     1# -*- test-case-name: twisted.test.test_endpoints -*-
     2
     3from zope.interface import implements, providedBy, directlyProvides
     4
     5from twisted.internet import interfaces
     6from twisted.internet import defer, protocol
     7from twisted.internet.protocol import ClientFactory
     8from twisted.protocols.policies import ProtocolWrapper
     9
     10class _WrappingProtocol(ProtocolWrapper):
     11    # FIXME: we probably don't need to use policies.ProtocolWrapper
     12    # with a little work we can just set up the wrappedProtocols
     13    # transport correctly, instead of pretending to be it.
     14    def connectionMade(self):
     15        ProtocolWrapper.connectionMade(self)
     16        self.factory.deferred.callback(self.wrappedProtocol)
     17
     18class _WrappingFactory(ClientFactory):
     19    protocol = _WrappingProtocol
     20
     21    def __init__(self, wrappedFactory):
     22        self.wrappedFactory = wrappedFactory
     23        self.deferred = defer.Deferred()
     24
     25    def buildProtocol(self, addr):
     26        try:
     27            proto = self.wrappedFactory.buildProtocol(addr)
     28        except:
     29            self.deferred.errback()
     30        else:
     31            return self.protocol(self, proto)
     32
     33    def registerProtocol(self, proto):
     34        pass
     35
     36    def unregisterProtocol(self, proto):
     37        pass
     38
     39    def clientConnectionFailed(self, connector, reason):
     40        self.deferred.errback(reason)
     41
     42
     43class TCPClientEndpoint(object):
     44    implements(interfaces.IClientEndpoint)
     45
     46    def __init__(self, host, port, timeout=30, bindAddress=None,
     47                 sslContextFactory=None):
     48
     49        self.host = host
     50        self.port = port
     51        self.timeout = timeout
     52        self.bindAddress = bindAddress
     53        self.sslContextFactory = sslContextFactory
     54       
     55    def connect(self, reactor, clientFactory):
     56        wf = _WrappingFactory(clientFactory)
     57        extraConnectArgs = {}
     58        connectMethod = reactor.connectTCP
     59        if self.sslContextFactory:
     60            connectMethod = reactor.connectSSL
     61            extraConnectArgs["contextFactory"] = self.sslContextFactory
     62           
     63        d = defer.execute(connectMethod, self.host, self.port, wf,
     64                          timeout=self.timeout,
     65                          bindAddress=self.bindAddress, **extraConnectArgs)
     66
     67        d.addCallback(lambda _: wf.deferred)
     68
     69        return d
     70
     71class TCPServerEndpoint(object):
     72    implements(interfaces.IServerEndpoint)
     73
     74    def __init__(self, port=0, backlog=50, interface='',
     75                 sslContextFactory=None):
     76
     77        self.port = port
     78        self.backlog = backlog
     79        self.interface = interface
     80        self.sslContextFactory = sslContextFactory
     81
     82    def listen(self, reactor, serverFactory):
     83        wf = _WrappingFactory(serverFactory)
     84        extraListenArgs = {}
     85        listenMethod = reactor.listenTCP
     86        if self.sslContextFactory:
     87            listenMethod = reactor.listenSSL
     88            extraListenArgs["contextFactory"] = self.sslContextFactory
     89        return defer.execute(listenMethod, self.port, wf,
     90                             backlog=self.backlog,
     91                             interface=self.interface, **extraListenArgs)
     92
     93class UNIXClientEndpoint(object):
     94    implements(interfaces.IClientEndpoint)
     95
     96    def __init__(self, address, timeout=30, checkPID=0):
     97        self.address = address
     98        self.timeout = timeout
     99        self.checkPID= checkPID
     100
     101    def connect(self, reactor, clientFactory):
     102        wf = _WrappingFactory(clientFactory)
     103        d = defer.execute(reactor.connectUNIX, self.address, wf,
     104                          timeout=self.timeout,
     105                          checkPID=self.checkPID)
     106
     107        d.addCallback(lambda _: wf.deferred)
     108
     109        return d
     110
     111class UNIXServerEndpoint(object):
     112    implements(interfaces.IServerEndpoint)
     113
     114    def __init__(self, address, backlog=50, mode=0666, wantPID=0):
     115        self.address = address
     116        self.backlog = backlog
     117        self.mode = mode
     118        self.wantPID= wantPID
     119
     120    def listen(self, reactor, serverFactory):
     121        wf = _WrappingFactory(serverFactory)
     122        return defer.execute(reactor.listenUNIX, self.address, wf,
     123                                                backlog=self.backlog,
     124                                                mode=self.mode,
     125                                                wantPID=self.wantPID)
  • twisted/internet/interfaces.py

     
    12941294
    12951295    def leaveGroup(addr, interface=""):
    12961296        """Leave multicast group, return Deferred of success."""
     1297
     1298class IClientEndpoint(Interface):
     1299    """Object that represents a remote endpoint that we wish to connect to.
     1300    """
     1301    def connect(reactor, clientFactory):
     1302        """
     1303        @param reactor: The reactor
     1304        @param clientFactory: A provider of L{IProtocolFactory}
     1305        @return: A L{Deferred} that results in an L{IProtocol} upon successful
     1306        connection otherwise a L{ConnectError}
     1307        """
     1308
     1309class IServerEndpoint(Interface):
     1310    """Object representing an endpoint where we will listen for connections.
     1311    """
     1312
     1313    def listen(callable):
     1314        """
     1315        @param reactor: The reactor
     1316        @param serverFactory: A provider of L{IProtocolFactory}
     1317        @return: A L{Deferred} that results in an L{IListeningPort} or an
     1318        L{CannotListenError}
     1319        """
  • twisted/internet/address.py

     
    55"""Address objects for network connections."""
    66
    77import warnings, os
     8
    89from zope.interface import implements
     10
    911from twisted.internet.interfaces import IAddress
    1012
    11 
    1213class IPv4Address(object):
    1314    """
    1415    Object representing an IPv4 socket endpoint.
     
    3435        self.port = port
    3536        self._bwHack = _bwHack
    3637
     38    def buildClientEndpoint(self):
     39        return endpoints.TCPClientEndpoint(self.host, self.port)
     40
     41    def buildServerEndpoint(self):
     42        return endpoints.TCPServerEndpoint(self.host, self.port)
     43
    3744    def __getitem__(self, index):
    3845        warnings.warn("IPv4Address.__getitem__ is deprecated.  Use attributes instead.",
    3946                      category=DeprecationWarning, stacklevel=2)
     
    7077    def __init__(self, name, _bwHack='UNIX'):
    7178        self.name = name
    7279        self._bwHack = _bwHack
    73    
     80
     81    def buildClientEndpoint(self):
     82        return endpoints.UNIXClientEndpoint(self.name)
     83
     84    def buildServerEndpoint(self):
     85        return endpoints.UNIXServerEndpoint(self.name)
     86
    7487    def __getitem__(self, index):
    7588        warnings.warn("UNIXAddress.__getitem__ is deprecated.  Use attributes instead.",
    7689                      category=DeprecationWarning, stacklevel=2)