Ticket #1442: endpoints3.diff
| File endpoints3.diff, 14.6 KB (added by rwall, 7 years ago) |
|---|
-
twisted/test/test_endpoints.py
1 from twisted.trial import unittest 2 from twisted.internet import (defer, error, interfaces, reactor, 3 _sslverify as sslverify) 4 from twisted.internet.address import IPv4Address, UNIXAddress 5 from twisted.internet.protocol import ClientFactory, Protocol, ServerFactory 6 from twisted.internet.endpoints import (TCPEndpoint, UNIXEndpoint) 7 from twisted.test.test_sslverify import makeCertificate 8 9 class ServerProtocol(Protocol): 10 def connectionMade(self): 11 self.factory.onConnectionMade.callback(self) 12 13 def connectionLost(self, *a): 14 self.factory.onConnectionLost.callback(self) 15 16 class ClientProtocol(Protocol): 17 def connectionMade(self): 18 self.factory.onConnectionMade.callback(self) 19 20 def connectionLost(self, *a): 21 self.factory.onConnectionLost.callback(self) 22 23 class MyServerFactory(ServerFactory): 24 protocol = ServerProtocol 25 26 def __init__(self): 27 self.onConnectionMade = defer.Deferred() 28 self.onConnectionLost = defer.Deferred() 29 30 class MyClientFactory(ClientFactory): 31 protocol = ClientProtocol 32 33 def __init__(self): 34 self.onConnectionMade = defer.Deferred() 35 self.onConnectionLost = defer.Deferred() 36 37 class PortAndConnectorCleanerUpper(unittest.TestCase): 38 def setUp(self): 39 self.listeningPorts = [] 40 self.clientConnections = [] 41 42 def tearDown(self): 43 map(lambda p: p.stopListening(), self.listeningPorts) 44 map(lambda c: c.disconnect(), self.clientConnections) 45 46 class EndpointTestCaseMixin(object): 47 def test_EndpointConnectSuccess(self): 48 """Test that Endpoint can connect and returns a deferred who 49 gets called back with a protocol instance. 50 """ 51 sf = MyServerFactory() 52 p = self.createServer(sf) 53 addr = p.getHost() 54 ep = self.createEndpoint(addr) 55 cf = MyClientFactory() 56 d = ep.connect(reactor, cf) 57 self.assertTrue(isinstance(d, defer.Deferred)) 58 def onConnectSuccess(proto): 59 self.assertTrue(interfaces.IProtocol.providedBy(proto)) 60 proto.transport.loseConnection() 61 d.addCallback(onConnectSuccess) 62 return defer.gatherResults([sf.onConnectionMade, cf.onConnectionMade, d]) 63 64 def test_EndpointConnectFailure(self): 65 """Test that if an Endpoint tries to connect to a none 66 listening port that it gets a ConnectError failure. 67 """ 68 p = self.createServer(MyServerFactory()) 69 addr = p.getHost() 70 p.loseConnection() 71 72 ep = self.createEndpoint(addr) 73 d = ep.connect(reactor, MyClientFactory()) 74 self.failUnlessFailure(d, error.ConnectError) 75 return d 76 77 def test_EndpointListenSuccess(self): 78 """Test that Endpoint can listen and returns a deferred that 79 gets called back with a port instance. 80 """ 81 ep = self.createEndpoint() 82 sf = MyServerFactory() 83 d = ep.listen(reactor, sf) 84 self.assertTrue(isinstance(d, defer.Deferred)) 85 def onListenSuccess(port): 86 self.assertTrue(interfaces.IListeningPort.providedBy(port)) 87 self.listeningPorts.append(port) 88 return port.getHost() 89 d.addCallback(onListenSuccess) 90 def connectTo(addr): 91 self.createClient(addr, MyClientFactory()) 92 d.addCallback(connectTo) 93 return defer.gatherResults([sf.onConnectionMade, d]) 94 95 def test_EndpointListenFailure(self): 96 """Test that if Endpoint tries to listen on an already listening 97 port, that a CannotListenError failure is errbacked. 98 """ 99 p = self.createServer(MyServerFactory()) 100 addr = p.getHost() 101 ep = self.createEndpoint(addr) 102 d = ep.listen(reactor, MyServerFactory()) 103 self.failUnlessFailure(d, error.CannotListenError) 104 return d 105 106 class TCPEndpointsTestCase(PortAndConnectorCleanerUpper, EndpointTestCaseMixin): 107 def createServer(self, factory): 108 p = reactor.listenTCP(0, factory) 109 self.listeningPorts.append(p) 110 return p 111 112 def createClient(self, address, factory): 113 c = reactor.connectTCP(address.host, address.port, factory) 114 self.clientConnections.append(c) 115 return c 116 117 def createEndpoint(self, address=None): 118 if not address: 119 address = IPv4Address("TCP", "localhost", 0) 120 return TCPEndpoint(address.host, address.port) 121 122 class SSLEndpointsTestCase(PortAndConnectorCleanerUpper, EndpointTestCaseMixin): 123 124 def setUpClass(self): 125 self.sKey, self.sCert = makeCertificate( 126 O="Server Test Certificate", 127 CN="server") 128 self.cKey, self.cCert = makeCertificate( 129 O="Client Test Certificate", 130 CN="client") 131 self.serverSSLContext = sslverify.OpenSSLCertificateOptions(privateKey=self.sKey, certificate=self.sCert, requireCertificate=False) 132 self.clientSSLContext = sslverify.OpenSSLCertificateOptions(requireCertificate=False) 133 134 def createServer(self, factory): 135 p = reactor.listenSSL(0, factory, self.serverSSLContext) 136 self.listeningPorts.append(p) 137 return p 138 139 def createClient(self, address, factory): 140 c = reactor.connectSSL(address.host, address.port, factory, 141 self.clientSSLContext) 142 self.clientConnections.append(c) 143 return c 144 145 def createEndpoint(self, address=None): 146 if not address: 147 address = IPv4Address("TCP", "localhost", 0) 148 return TCPEndpoint(address.host, address.port, 149 sslContextFactory=self.clientSSLContext) 150 151 152 class UNIXEndpointsTestCase(PortAndConnectorCleanerUpper, EndpointTestCaseMixin): 153 def createServer(self, factory): 154 p = reactor.listenUNIX(self.mktemp(), factory) 155 self.listeningPorts.append(p) 156 return p 157 158 def createClient(self, address, factory): 159 c = reactor.connectUNIX(address.name, factory) 160 self.clientConnections.append(c) 161 return c 162 163 def createEndpoint(self, address=None): 164 if not address: 165 address = UNIXAddress(self.mktemp()) 166 return UNIXEndpoint(address.name) -
twisted/internet/endpoints.py
1 # -*- test-case-name: twisted.test.test_endpoints -*- 2 3 from zope.interface import implements, providedBy, directlyProvides 4 5 from twisted.internet import interfaces 6 from twisted.internet import defer, protocol 7 from twisted.internet.protocol import ClientFactory, Protocol 8 9 class _WrappingProtocol(Protocol): 10 """I wrap another protocol in order to notify my user when a connection has 11 been made. 12 """ 13 def __init__(self, factory, wrappedProtocol): 14 self.factory = factory 15 self.wrappedProtocol = wrappedProtocol 16 17 def connectionMade(self): 18 """XXX: As soon as I am connected, I connect my wrappedProtocol, giving 19 it my transport. Is it okay for a transport to be associated with more 20 than one protocol? Transport calls dataReceived on me and I in turn call 21 dataReceived on my wrappedProtocol. The wrappedProtocol may call 22 transport.write or transport.loseConnection etc 23 """ 24 25 self.wrappedProtocol.makeConnection(self.transport) 26 self.factory.deferred.callback(self.wrappedProtocol) 27 28 def dataReceived(self, data): 29 return self.wrappedProtocol.dataReceived(data) 30 31 def connectionLost(self, reason): 32 return self.wrappedProtocol.connectionLost(reason) 33 34 class _WrappingFactory(ClientFactory): 35 protocol = _WrappingProtocol 36 37 def __init__(self, wrappedFactory): 38 self.wrappedFactory = wrappedFactory 39 self.deferred = defer.Deferred() 40 41 def buildProtocol(self, addr): 42 try: 43 proto = self.wrappedFactory.buildProtocol(addr) 44 except: 45 self.deferred.errback() 46 else: 47 return self.protocol(self, proto) 48 49 def clientConnectionFailed(self, connector, reason): 50 self.deferred.errback(reason) 51 52 53 class TCPEndpoint(object): 54 implements(interfaces.IClientEndpoint, interfaces.IServerEndpoint) 55 56 def __init__(self, host, port, connectArgs={}, listenArgs={}, 57 sslContextFactory=None): 58 """ 59 @param host: A hostname, used only when connecting 60 @param port: The port number, used both when connecting and listening 61 @param connectArgs: An optional dict of keyword args that will be passed 62 to L{twisted.internet.interfaces.IReactorTCP.connectTCP} 63 @param listenArgs: An optional dict of keyword args that will be passed 64 to L{twisted.internet.interfaces.IReactorTCP.listenTCP} 65 @param sslContextFactory: An optional instance of 66 L{twisted.internet._sslverify.OpenSSLCertificateOptions}. If given, it 67 makes L{connect} and L{listen} use the corresponding methods from 68 L{twisted.internet.interfaces.IReactorSSL} 69 """ 70 self.host = host 71 self.port = port 72 self.connectArgs = dict(timeout=30, bindAddress=None) 73 self.connectArgs.update(connectArgs) 74 self.listenArgs = dict(backlog=50, interface='') 75 self.listenArgs.update(listenArgs) 76 self.sslContextFactory = sslContextFactory 77 78 def connect(self, reactor, clientFactory): 79 wf = _WrappingFactory(clientFactory) 80 connectArgs = self.connectArgs 81 connectMethod = reactor.connectTCP 82 if self.sslContextFactory: 83 connectMethod = reactor.connectSSL 84 connectArgs["contextFactory"] = self.sslContextFactory 85 86 d = defer.execute(connectMethod, self.host, self.port, wf, **connectArgs) 87 88 d.addCallback(lambda _: wf.deferred) 89 90 return d 91 92 def listen(self, reactor, serverFactory): 93 wf = _WrappingFactory(serverFactory) 94 listenArgs = self.listenArgs 95 listenMethod = reactor.listenTCP 96 if self.sslContextFactory: 97 listenMethod = reactor.listenSSL 98 listenArgs["contextFactory"] = self.sslContextFactory 99 return defer.execute(listenMethod, self.port, wf, **listenArgs) 100 101 class UNIXEndpoint(object): 102 implements(interfaces.IClientEndpoint, interfaces.IServerEndpoint) 103 104 def __init__(self, address, connectArgs={}, listenArgs={}): 105 """ 106 @param address: The path to the Unix socket file, used both when 107 connecting and listening 108 @param connectArgs: A dict of keyword args that will be passed to 109 L{twisted.internet.interfaces.IReactorUNIX.connectUNIX} 110 @param listenArgs: A dict of keyword args that will be passed to 111 L{twisted.internet.interfaces.IReactorUNIX.listenUNIX} 112 """ 113 114 self.address = address 115 self.connectArgs = dict(timeout=30, checkPID=0) 116 self.connectArgs.update(connectArgs) 117 self.listenArgs = dict(backlog=50, mode=0666, wantPID=0) 118 self.listenArgs.update(listenArgs) 119 120 def connect(self, reactor, clientFactory): 121 wf = _WrappingFactory(clientFactory) 122 d = defer.execute(reactor.connectUNIX, self.address, wf, 123 **self.connectArgs) 124 125 d.addCallback(lambda _: wf.deferred) 126 127 return d 128 129 def listen(self, reactor, serverFactory): 130 wf = _WrappingFactory(serverFactory) 131 return defer.execute(reactor.listenUNIX, self.address, wf, 132 **self.listenArgs) -
twisted/internet/interfaces.py
18 18 19 19 Default implementations are in L{twisted.internet.address}. 20 20 """ 21 22 def buildEndpoint(): 23 """ 24 @return: an instance providing both L{IClientEndpoint} and L{IServerEndpoint} 25 """ 21 26 22 27 23 28 ### Reactor Interfaces … … 1294 1299 1295 1300 def leaveGroup(addr, interface=""): 1296 1301 """Leave multicast group, return Deferred of success.""" 1302 1303 class IClientEndpoint(Interface): 1304 """Object that represents a remote endpoint that we wish to connect to. 1305 """ 1306 def connect(reactor, clientFactory): 1307 """ 1308 @param reactor: The reactor 1309 @param clientFactory: A provider of L{IProtocolFactory} 1310 @return: A L{Deferred} that results in an L{IProtocol} upon successful 1311 connection otherwise a L{ConnectError} 1312 """ 1313 1314 class IServerEndpoint(Interface): 1315 """Object representing an endpoint where we will listen for connections. 1316 """ 1317 1318 def listen(callable): 1319 """ 1320 @param reactor: The reactor 1321 @param serverFactory: A provider of L{IProtocolFactory} 1322 @return: A L{Deferred} that results in an L{IListeningPort} or an 1323 L{CannotListenError} 1324 """ -
twisted/internet/address.py
5 5 """Address objects for network connections.""" 6 6 7 7 import warnings, os 8 8 9 from zope.interface import implements 10 9 11 from twisted.internet.interfaces import IAddress 12 from twisted.internet.endpoints import TCPEndpoint, UNIXEndpoint 10 13 11 12 14 class IPv4Address(object): 13 15 """ 14 16 Object representing an IPv4 socket endpoint. … … 34 36 self.port = port 35 37 self._bwHack = _bwHack 36 38 39 def buildEndpoint(self): 40 if self.type == "TCP": 41 return TCPEndpoint(self.host, self.port) 42 else: 43 raise NotImplementedError 44 37 45 def __getitem__(self, index): 38 46 warnings.warn("IPv4Address.__getitem__ is deprecated. Use attributes instead.", 39 47 category=DeprecationWarning, stacklevel=2) … … 70 78 def __init__(self, name, _bwHack='UNIX'): 71 79 self.name = name 72 80 self._bwHack = _bwHack 73 81 82 def buildEndpoint(self): 83 return UNIXEndpoint(self.name) 84 74 85 def __getitem__(self, index): 75 86 warnings.warn("UNIXAddress.__getitem__ is deprecated. Use attributes instead.", 76 87 category=DeprecationWarning, stacklevel=2)
