[Twisted-Python] DNS SRV support for Connectors

Tommi Virtanen tv at twistedmatrix.com
Sun Dec 29 17:47:04 EST 2002


	Hi. I wanted SRV record (RFC 2782) support for connecting
	to services, so here it is.

	The following patch adds a new Connector class that takes a
	service name (smtp, http, ldap, ...) and a domain (example.com),
	looks up the DNS SRV records for it, and falls back sanely if none
	are found (connecting to the domain itself on the service's
	well-known port).

	Itamar pointed out that this might be better placed in
	twisted.names.client -- feel free to tell me where the
	"right" place is.

	Please give feedback.

	The patch also fixes the output of the SRV example
	(doc/examples/dns-service.py); it previously did not print the
	interesting parts of the records.

	The patch is attached, and the latest version is at
	http://www.twistedmatrix.com/users/tv/connectTCPService.patch

-- 
:(){ :|:&};:
-------------- next part --------------
Index: twisted/internet/default.py
===================================================================
RCS file: /cvs/Twisted/twisted/internet/default.py,v
retrieving revision 1.53
diff -u -u -r1.53 default.py
--- twisted/internet/default.py	17 Dec 2002 02:31:59 -0000	1.53
+++ twisted/internet/default.py	29 Dec 2002 22:31:24 -0000
@@ -29,12 +29,15 @@
 import os
 import socket
 import sys
+import types
+import random
 
 from twisted.internet.interfaces import IReactorCore, IReactorTime, IReactorUNIX
 from twisted.internet.interfaces import IReactorTCP, IReactorUDP, IReactorSSL
 from twisted.internet.interfaces import IReactorProcess, IReactorFDSet, IReactorMulticast
 from twisted.internet import main, error, protocol, interfaces
 from twisted.internet import tcp, udp, task, defer
+from twisted.names import client
 
 
 from twisted.python import log, threadable, failure
@@ -142,6 +145,10 @@
 
     def __init__(self, reactor, host, port, factory, timeout, bindAddress):
         self.host = host
+        if isinstance(port, types.IntType):
+            port = port
+        else:
+            port = socket.getservbyname(port, 'tcp')
         self.port = port
         self.bindAddress = bindAddress
         BaseConnector.__init__(self, reactor, factory, timeout)
@@ -152,6 +159,107 @@
     def getDestination(self):
         return ('INET', self.host, self.port)
 
+class TCPServiceConnector(BaseConnector):
+    """Connect to C{service} on C{domain} via DNS SRV records (RFC 2782);
+    with no records, fall back to (domain, service) as host and port.  A
+    single "." target means the service is unavailable and fails the connect."""
+    stopAfterDNS=0
+
+    def __init__(self, reactor, service, domain, factory, timeout, bindAddress):
+        self.service = service
+        self.domain = domain
+        self.bindAddress = bindAddress
+        self.timeout = timeout # forwarded to connectTCP by _reallyConnect
+        BaseConnector.__init__(self, reactor, factory, timeout)
+        self.connector = None      # inner connector from reactor.connectTCP
+        self.servers = None        # servers left this round; None until looked up
+        self.orderedServers = None # list of servers already used in this round
+
+    def connect(self):
+        if self.servers is None:
+            # look up only once; a failed lookup leaves servers None, so retry
+            d = client.theResolver.lookupService('_%s._tcp.%s' % (self.service, self.domain))
+            d.addCallback(self._cbGotServers)
+            d.addCallback(lambda x, self=self: self._reallyConnect())
+            d.addErrback(self.connectionFailed)
+        elif self.connector is None:
+            self._reallyConnect()
+        else:
+            self.connector.connect()
+
+    def _cbGotServers(self, answers):
+        # a lone root target means "decidedly not available"; the root name
+        # may render as '.' or '', so compare its string form against both
+        if len(answers)==1 and str(answers[0].payload.target) in ('.', ''):
+            raise error.DNSLookupError("Service %s not available for domain %s."
+                                       % (repr(self.service), repr(self.domain)))
+
+        self.servers = []
+        self.orderedServers = []
+        for a in answers:
+            self.orderedServers.append((a.payload.priority, a.payload.weight,
+                                        str(a.payload.target), a.payload.port))
+
+    def _serverCmp(self, a, b):
+        # (priority, weight) order; zero weights first, as RFC 2782 requires
+        return cmp(a[:2], b[:2])
+
+    def pickServer(self):
+        """Remove and return the next (host, port) to try this round."""
+        assert self.servers is not None
+        assert self.orderedServers is not None
+
+        if not self.servers and not self.orderedServers:
+            # no SRV record, fall back..
+            return self.domain, self.service
+
+        if not self.servers:
+            # start new round
+            self.servers = self.orderedServers
+            self.orderedServers = []
+
+        self.servers.sort(self._serverCmp)
+        minPriority=self.servers[0][0]
+        # the lowest-priority records are a prefix of the sorted list
+        weights = [x[1] for x in self.servers if x[0]==minPriority]
+        rand = random.randint(0, reduce(lambda x, y: x+y, weights, 0))
+
+        # RFC 2782: take the first record whose running weight sum >= rand
+        runningSum = 0
+        for index in xrange(len(weights)):
+            runningSum = runningSum + weights[index]
+            if runningSum >= rand:
+                chosen = self.servers[index]
+                del self.servers[index]
+                self.orderedServers.append(chosen)
+                return chosen[2], chosen[3]
+
+        # unreachable: the last running sum equals the total weight >= rand
+        raise RuntimeError('We really should never get here!')
+
+    def _reallyConnect(self):
+        if self.stopAfterDNS:
+            self.stopAfterDNS=0
+            return
+
+        self.host, self.port = self.pickServer()
+        # TODO connectSSL?
+        self.connector=self.reactor.connectTCP(self.host, self.port, self.factory,
+                                               self.timeout, self.bindAddress)
+
+    def stopConnecting(self):
+        if self.connector:
+            self.connector.stopConnecting()
+        else:
+            self.stopAfterDNS=1
+
+    def disconnect(self):
+        if self.connector is not None:
+            self.connector.disconnect()
+
+    def getDestination(self):
+        assert self.connector
+        return self.connector.getDestination()
 
 class UNIXConnector(BaseConnector):
 
@@ -347,6 +455,13 @@
         """See twisted.internet.interfaces.IReactorTCP.connectTCP
         """
         c = TCPConnector(self, host, port, factory, timeout, bindAddress)
+        c.connect()
+        return c
+
+    def connectTCPService(self, service, domain, factory, timeout=30, bindAddress=None):
+        """Connect factory to service on domain, found via DNS SRV (RFC 2782).
+        Returns the TCPServiceConnector (TODO: declare in IReactorTCP)."""
         c = TCPServiceConnector(self, service, domain, factory, timeout, bindAddress)
         c.connect()
         return c
 
Index: doc/examples/dns-service.py
===================================================================
RCS file: /cvs/Twisted/doc/examples/dns-service.py,v
retrieving revision 1.1
diff -u -u -r1.1 dns-service.py
--- doc/examples/dns-service.py	23 Dec 2002 22:22:53 -0000	1.1
+++ doc/examples/dns-service.py	29 Dec 2002 22:31:24 -0000
@@ -10,7 +10,7 @@
     if not len(answers):
         print 'No answers'
     else:
-        print '\n'.join(map(str, answers))
+        print '\n'.join([str(x.payload) for x in answers])
     reactor.stop()
 
 def printFailure(arg):


More information about the Twisted-Python mailing list