Ticket #1442: endpoints2.diff

File endpoints2.diff, 13.6 KB (added by rwall, 8 years ago)

Dreid's branch, with regular factories, SSL support, and some tests.

  • twisted/test/test_endpoints.py

     
     1from twisted.trial import unittest 
     2from twisted.internet import (defer, error, interfaces, reactor,  
     3                              _sslverify as sslverify) 
     4from twisted.internet.protocol import ClientFactory, Protocol, ServerFactory 
     5from twisted.internet.endpoints import (TCPClientEndpoint, TCPServerEndpoint,  
     6                                        UNIXClientEndpoint, UNIXServerEndpoint) 
     7from twisted.test.test_sslverify import makeCertificate 
     8 
     9class ServerProtocol(Protocol): 
     10    def connectionLost(self, *a): 
     11        self.factory.onConnectionLost.callback(self) 
     12 
     13class ClientProtocol(Protocol): 
     14    def connectionMade(self): 
     15        self.factory.onConnectionMade.callback(self) 
     16 
     17    def connectionLost(self, *a): 
     18        self.factory.onConnectionLost.callback(self) 
     19 
     20class MyServerFactory(ServerFactory): 
     21    protocol = ServerProtocol 
     22     
     23    def __init__(self): 
     24        self.onConnectionLost = defer.Deferred() 
     25 
     26class MyClientFactory(ClientFactory): 
     27    protocol = ClientProtocol 
     28     
     29    def __init__(self): 
     30        self.onConnectionMade = defer.Deferred() 
     31        self.onConnectionLost = defer.Deferred() 
     32 
     33class PortAndConnectorCleanerUpper(unittest.TestCase): 
     34    def setUp(self): 
     35        self.listeningPorts = [] 
     36        self.clientConnections = [] 
     37     
     38    def tearDown(self): 
     39        map(lambda p: p.stopListening(), self.listeningPorts) 
     40        map(lambda c: c.disconnect(), self.clientConnections) 
     41 
     42class EndpointTestCaseMixin(object): 
     43    def test_clientEndpointConnectSuccess(self): 
     44        """Test that Endpoint can connect and returns a deferred who 
     45        gets called back with a protocol instance.  
     46        """ 
     47        p = self.createServer(MyServerFactory()) 
     48        addr = p.getHost() 
     49        ep = self.createClientEndpoint(addr) 
     50        d = ep.connect(reactor, MyClientFactory()) 
     51        self.assertTrue(isinstance(d, defer.Deferred)) 
     52        def onConnect(proto): 
     53            self.assertTrue(interfaces.IProtocol.providedBy(proto)) 
     54            proto.transport.loseConnection() 
     55        d.addCallback(onConnect) 
     56        return d 
     57     
     58    def test_clientEndpointConnectFailure(self): 
     59        """Test that when a ClientEndpoint tries to connect to a none  
     60        listening port that it gets a ConnectError failure. 
     61        """ 
     62        p = self.createServer(MyServerFactory()) 
     63        addr = p.getHost() 
     64        p.loseConnection() 
     65 
     66        ep = self.createClientEndpoint(addr) 
     67        d = ep.connect(reactor, MyClientFactory()) 
     68        self.failUnlessFailure(d, error.ConnectError) 
     69        return d 
     70     
     71    def test_serverEndpointListenSuccess(self): 
     72        """Test that ServerEndpoint can listen and returns a deferred that 
     73        gets called back with a port instance.  
     74        """ 
     75        ep = self.createServerEndpoint() 
     76        d = ep.listen(reactor, MyServerFactory()) 
     77        self.assertTrue(isinstance(d, defer.Deferred)) 
     78        def onListen(port): 
     79            self.assertTrue(interfaces.IListeningPort.providedBy(port)) 
     80            self.listeningPorts.append(port) 
     81            return port.getHost() 
     82        d.addCallback(onListen) 
     83        def connectTo(addr): 
     84            self.createClient(addr, MyClientFactory()) 
     85        d.addCallback(connectTo) 
     86        return d 
     87 
     88    def test_serverEndpointListenFailure(self): 
     89        """Test that if ServerEndpoint tries to listen on an already listening  
     90        port, that a CannotListenError failure is errbacked.  
     91        """ 
     92        p = self.createServer(MyServerFactory()) 
     93        addr = p.getHost() 
     94        ep = self.createServerEndpoint(addr) 
     95        d = ep.listen(reactor, MyServerFactory()) 
     96        self.failUnlessFailure(d, error.CannotListenError) 
     97        return d 
     98 
     99class TCPEndpointsTestCase(PortAndConnectorCleanerUpper, EndpointTestCaseMixin): 
     100    def createServer(self, factory): 
     101        p = reactor.listenTCP(0, factory) 
     102        self.listeningPorts.append(p) 
     103        return p 
     104 
     105    def createClient(self, address, factory): 
     106        c = reactor.connectTCP(address.host, address.port, factory) 
     107        self.clientConnections.append(c) 
     108        return c 
     109 
     110    def createClientEndpoint(self, address): 
     111        return TCPClientEndpoint(address.host, address.port) 
     112 
     113    def createServerEndpoint(self, addr=None): 
     114        arg = 0 
     115        if addr: 
     116            arg = addr.port 
     117        return TCPServerEndpoint(arg) 
     118 
     119class SSLEndpointsTestCase(PortAndConnectorCleanerUpper, EndpointTestCaseMixin): 
     120     
     121    def setUpClass(self): 
     122        self.sKey, self.sCert = makeCertificate( 
     123            O="Server Test Certificate", 
     124            CN="server") 
     125        self.cKey, self.cCert = makeCertificate( 
     126            O="Client Test Certificate", 
     127            CN="client") 
     128        self.serverSSLContext = sslverify.OpenSSLCertificateOptions(privateKey=self.sKey, certificate=self.sCert, requireCertificate=False) 
     129        self.clientSSLContext = sslverify.OpenSSLCertificateOptions(requireCertificate=False) 
     130         
     131    def createServer(self, factory): 
     132        p = reactor.listenSSL(0, factory, self.serverSSLContext) 
     133        self.listeningPorts.append(p) 
     134        return p 
     135 
     136    def createClient(self, address, factory): 
     137        c = reactor.connectSSL(address.host, address.port, factory,  
     138                               self.clientSSLContext) 
     139        self.clientConnections.append(c) 
     140        return c 
     141 
     142    def createClientEndpoint(self, address): 
     143        return TCPClientEndpoint(address.host, address.port,  
     144                                 sslContextFactory=self.clientSSLContext) 
     145 
     146    def createServerEndpoint(self, addr=None): 
     147        arg = 0 
     148        if addr: 
     149            arg = addr.port 
     150        return TCPServerEndpoint(arg, sslContextFactory=self.serverSSLContext) 
     151         
     152 
     153class UNIXEndpointsTestCase(PortAndConnectorCleanerUpper, EndpointTestCaseMixin): 
     154    def createServer(self, factory): 
     155        p = reactor.listenUNIX(self.mktemp(), factory) 
     156        self.listeningPorts.append(p) 
     157        return p 
     158 
     159    def createClient(self, address, factory): 
     160        c = reactor.connectUNIX(address.name, factory) 
     161        self.clientConnections.append(c) 
     162        return c 
     163 
     164    def createClientEndpoint(self, address): 
     165        return UNIXClientEndpoint(address.name) 
     166 
     167    def createServerEndpoint(self, addr=None): 
     168        arg = self.mktemp() 
     169        if addr: 
     170            arg = addr.name 
     171        return UNIXServerEndpoint(arg) 
  • twisted/internet/endpoints.py

     
     1# -*- test-case-name: twisted.test.test_endpoints -*- 
     2 
     3from zope.interface import implements, providedBy, directlyProvides 
     4 
     5from twisted.internet import interfaces 
     6from twisted.internet import defer, protocol 
     7from twisted.internet.protocol import ClientFactory 
     8from twisted.protocols.policies import ProtocolWrapper 
     9 
     10class _WrappingProtocol(ProtocolWrapper): 
     11    # FIXME: we probably don't need to use policies.ProtocolWrapper 
     12    # with a little work we can just set up the wrappedProtocols 
     13    # transport correctly, instead of pretending to be it. 
     14    def connectionMade(self): 
     15        ProtocolWrapper.connectionMade(self) 
     16        self.factory.deferred.callback(self.wrappedProtocol) 
     17 
     18class _WrappingFactory(ClientFactory): 
     19    protocol = _WrappingProtocol 
     20 
     21    def __init__(self, wrappedFactory): 
     22        self.wrappedFactory = wrappedFactory 
     23        self.deferred = defer.Deferred() 
     24 
     25    def buildProtocol(self, addr): 
     26        try: 
     27            proto = self.wrappedFactory.buildProtocol(addr) 
     28        except: 
     29            self.deferred.errback() 
     30        else: 
     31            return self.protocol(self, proto) 
     32 
     33    def registerProtocol(self, proto): 
     34        pass 
     35 
     36    def unregisterProtocol(self, proto): 
     37        pass 
     38 
     39    def clientConnectionFailed(self, connector, reason): 
     40        self.deferred.errback(reason) 
     41 
     42 
     43class TCPClientEndpoint(object): 
     44    implements(interfaces.IClientEndpoint) 
     45 
     46    def __init__(self, host, port, timeout=30, bindAddress=None,  
     47                 sslContextFactory=None): 
     48 
     49        self.host = host 
     50        self.port = port 
     51        self.timeout = timeout 
     52        self.bindAddress = bindAddress 
     53        self.sslContextFactory = sslContextFactory 
     54         
     55    def connect(self, reactor, clientFactory): 
     56        wf = _WrappingFactory(clientFactory) 
     57        extraConnectArgs = {} 
     58        connectMethod = reactor.connectTCP 
     59        if self.sslContextFactory: 
     60            connectMethod = reactor.connectSSL 
     61            extraConnectArgs["contextFactory"] = self.sslContextFactory 
     62             
     63        d = defer.execute(connectMethod, self.host, self.port, wf, 
     64                          timeout=self.timeout, 
     65                          bindAddress=self.bindAddress, **extraConnectArgs) 
     66 
     67        d.addCallback(lambda _: wf.deferred) 
     68 
     69        return d 
     70 
     71class TCPServerEndpoint(object): 
     72    implements(interfaces.IServerEndpoint) 
     73 
     74    def __init__(self, port=0, backlog=50, interface='',  
     75                 sslContextFactory=None): 
     76 
     77        self.port = port 
     78        self.backlog = backlog 
     79        self.interface = interface 
     80        self.sslContextFactory = sslContextFactory 
     81 
     82    def listen(self, reactor, serverFactory): 
     83        wf = _WrappingFactory(serverFactory) 
     84        extraListenArgs = {} 
     85        listenMethod = reactor.listenTCP 
     86        if self.sslContextFactory: 
     87            listenMethod = reactor.listenSSL 
     88            extraListenArgs["contextFactory"] = self.sslContextFactory 
     89        return defer.execute(listenMethod, self.port, wf, 
     90                             backlog=self.backlog, 
     91                             interface=self.interface, **extraListenArgs) 
     92 
     93class UNIXClientEndpoint(object): 
     94    implements(interfaces.IClientEndpoint) 
     95 
     96    def __init__(self, address, timeout=30, checkPID=0): 
     97        self.address = address 
     98        self.timeout = timeout 
     99        self.checkPID= checkPID 
     100 
     101    def connect(self, reactor, clientFactory): 
     102        wf = _WrappingFactory(clientFactory) 
     103        d = defer.execute(reactor.connectUNIX, self.address, wf, 
     104                          timeout=self.timeout, 
     105                          checkPID=self.checkPID) 
     106 
     107        d.addCallback(lambda _: wf.deferred) 
     108 
     109        return d 
     110 
     111class UNIXServerEndpoint(object): 
     112    implements(interfaces.IServerEndpoint) 
     113 
     114    def __init__(self, address, backlog=50, mode=0666, wantPID=0): 
     115        self.address = address 
     116        self.backlog = backlog 
     117        self.mode = mode 
     118        self.wantPID= wantPID 
     119 
     120    def listen(self, reactor, serverFactory): 
     121        wf = _WrappingFactory(serverFactory) 
     122        return defer.execute(reactor.listenUNIX, self.address, wf, 
     123                                                backlog=self.backlog, 
     124                                                mode=self.mode, 
     125                                                wantPID=self.wantPID) 
  • twisted/internet/interfaces.py

     
    12941294 
    12951295    def leaveGroup(addr, interface=""): 
    12961296        """Leave multicast group, return Deferred of success.""" 
     1297 
     1298class IClientEndpoint(Interface): 
     1299    """Object that represents a remote endpoint that we wish to connect to. 
     1300    """ 
     1301    def connect(reactor, clientFactory): 
     1302        """ 
     1303        @param reactor: The reactor 
     1304        @param clientFactory: A provider of L{IProtocolFactory} 
     1305        @return: A L{Deferred} that results in an L{IProtocol} upon successful 
     1306        connection otherwise a L{ConnectError} 
     1307        """ 
     1308 
     1309class IServerEndpoint(Interface): 
     1310    """Object representing an endpoint where we will listen for connections. 
     1311    """ 
     1312 
     1313    def listen(callable): 
     1314        """ 
     1315        @param reactor: The reactor 
     1316        @param serverFactory: A provider of L{IProtocolFactory} 
     1317        @return: A L{Deferred} that results in an L{IListeningPort} or an  
     1318        L{CannotListenError} 
     1319        """ 
  • twisted/internet/address.py

     
    55"""Address objects for network connections.""" 
    66 
    77import warnings, os 
     8 
    89from zope.interface import implements 
     10 
    911from twisted.internet.interfaces import IAddress 
    1012 
    11  
    1213class IPv4Address(object): 
    1314    """ 
    1415    Object representing an IPv4 socket endpoint. 
     
    3435        self.port = port 
    3536        self._bwHack = _bwHack 
    3637 
     38    def buildClientEndpoint(self): 
     39        return endpoints.TCPClientEndpoint(self.host, self.port) 
     40 
     41    def buildServerEndpoint(self): 
     42        return endpoints.TCPServerEndpoint(self.host, self.port) 
     43 
    3744    def __getitem__(self, index): 
    3845        warnings.warn("IPv4Address.__getitem__ is deprecated.  Use attributes instead.", 
    3946                      category=DeprecationWarning, stacklevel=2) 
     
    7077    def __init__(self, name, _bwHack='UNIX'): 
    7178        self.name = name 
    7279        self._bwHack = _bwHack 
    73      
     80 
     81    def buildClientEndpoint(self): 
     82        return endpoints.UNIXClientEndpoint(self.name) 
     83 
     84    def buildServerEndpoint(self): 
     85        return endpoints.UNIXServerEndpoint(self.name) 
     86 
    7487    def __getitem__(self, index): 
    7588        warnings.warn("UNIXAddress.__getitem__ is deprecated.  Use attributes instead.", 
    7689                      category=DeprecationWarning, stacklevel=2)