root / trunk / twisted / names / srvconnect.py

Revision 21159, 6.2 kB (checked in by ralphm, 2 years ago)

Make SRVConnector work with bad results and NXDOMAIN responses.

Author: ralphm
Reviewer: therve
Fixes #1908, #2777.

Line 
1 # -*- test-case-name: twisted.names.test.test_srvconnect -*-
2 # Copyright (c) 2001-2007 Twisted Matrix Laboratories.
3 # See LICENSE for details.
4
5 import random
6
7 from zope.interface import implements
8
9 from twisted.internet import error, interfaces
10
11 from twisted.names import client, dns
12 from twisted.names.error import DNSNameError
13
class _SRVConnector_ClientFactoryWrapper:
    """
    Factory proxy handed to the reactor's real connect function.

    Connection-attempt and connection-outcome callbacks are redirected so
    that they report the SRV-level connector instead of the low-level one.
    Every other attribute lookup is forwarded to the wrapped factory.
    """

    def __init__(self, connector, wrappedFactory):
        # Double-underscore names are mangled, so they cannot shadow
        # attributes that are meant to be forwarded to the wrapped factory.
        self.__wrappedFactory = wrappedFactory
        self.__connector = connector

    def __getattr__(self, key):
        # Only consulted for names not found on the wrapper itself.
        wrapped = self.__wrappedFactory
        return getattr(wrapped, key)

    def startedConnecting(self, connector):
        # The low-level connector argument is ignored; the wrapped factory
        # is told about the SRV connector instead.
        factory = self.__wrappedFactory
        factory.startedConnecting(self.__connector)

    def clientConnectionFailed(self, connector, reason):
        srvConnector = self.__connector
        srvConnector.connectionFailed(reason)

    def clientConnectionLost(self, connector, reason):
        srvConnector = self.__connector
        srvConnector.connectionLost(reason)

class SRVConnector:
    """A connector that looks up DNS SRV records. See RFC2782.

    The connector resolves C{_<service>._<protocol>.<domain>} and picks a
    target from the answers (see L{pickServer}). It then calls the reactor
    method named by C{connectFuncName} with that host and port. Connection
    events from the underlying connector are relayed to C{factory} with this
    object as the connector.
    """

    implements(interfaces.IConnector)

    # Set by stopConnecting() when no underlying connector exists yet; makes
    # the next _reallyConnect() call return without connecting.
    stopAfterDNS=0

    def __init__(self, reactor, service, domain, factory,
                 protocol='tcp', connectFuncName='connectTCP',
                 connectFuncArgs=(),
                 connectFuncKwArgs=None,
                 ):
        """
        @param reactor: object providing the method named by
            C{connectFuncName}.
        @param service: SRV service label (without the leading underscore).
            It is also used as the port when no SRV records are found.
        @param domain: domain to look up. If C{None}, L{connect} fails
            immediately with L{error.DNSLookupError}.
        @param factory: client factory that receives connection events.
        @param protocol: SRV protocol label (without the leading underscore).
        @param connectFuncName: name of the reactor method used to connect.
        @param connectFuncArgs: extra positional arguments for that method.
        @param connectFuncKwArgs: extra keyword arguments for that method.
            C{None} (the default) means no extra keyword arguments.
        """
        self.reactor = reactor
        self.service = service
        self.domain = domain
        self.factory = factory

        self.protocol = protocol
        self.connectFuncName = connectFuncName
        self.connectFuncArgs = connectFuncArgs
        # Use a fresh dict per instance rather than a shared mutable default.
        if connectFuncKwArgs is None:
            connectFuncKwArgs = {}
        self.connectFuncKwArgs = connectFuncKwArgs

        self.connector = None
        self.servers = None # servers not yet tried in this round
        self.orderedServers = None # list of servers already used in this round

    def connect(self):
        """Start connection to remote server.

        If no servers are left in the current round, an SRV lookup is made
        first and a server is picked from its result. If servers are known
        but no underlying connector exists yet, a server is picked directly.
        Otherwise the existing underlying connector is asked to connect
        again.
        """
        self.factory.doStart()
        self.factory.startedConnecting(self)

        if not self.servers:
            if self.domain is None:
                self.connectionFailed(error.DNSLookupError("Domain is not defined."))
                return
            d = client.lookupService('_%s._%s.%s' % (self.service,
                                                     self.protocol,
                                                     self.domain))
            d.addCallbacks(self._cbGotServers, self._ebGotServers)
            d.addCallback(lambda ignored: self._reallyConnect())
            # Lookup errors not absorbed by _ebGotServers, and errors raised
            # while starting the connection, are reported to the factory.
            d.addErrback(self.connectionFailed)
        elif self.connector is None:
            self._reallyConnect()
        else:
            self.connector.connect()

    def _ebGotServers(self, failure):
        """
        Treat an NXDOMAIN (L{DNSNameError}) answer as "no SRV records".

        Any other failure is re-raised by C{trap} and continues down the
        errback chain.
        """
        failure.trap(DNSNameError)

        # Some DNS servers reply with NXDOMAIN when in fact there are
        # just no SRV records for that domain. Act as if we just got an
        # empty response and use fallback.

        self.servers = []
        self.orderedServers = []

    def _cbGotServers(self, result):
        """
        Record the SRV answers for use by L{pickServer}.

        @param result: C{(answers, authority, additional)} record lists
            from the lookup.
        @raise error.DNSLookupError: if the only answer is an SRV record
            whose target is C{.}, meaning the service is not available.
        """
        answers, auth, add = result
        if len(answers) == 1 and answers[0].type == dns.SRV \
                             and answers[0].payload \
                             and answers[0].payload.target == dns.Name('.'):
            # decidedly not available
            raise error.DNSLookupError("Service %s not available for domain %s."
                                       % (repr(self.service), repr(self.domain)))

        # All usable records start in orderedServers; pickServer() moves
        # them into servers when it begins a new round.
        self.servers = []
        self.orderedServers = []
        for a in answers:
            if a.type != dns.SRV or not a.payload:
                continue

            self.orderedServers.append((a.payload.priority, a.payload.weight,
                                        str(a.payload.target), a.payload.port))

    def _serverCmp(self, a, b):
        """
        Compare two server tuples by priority, then by weight.

        Returns -1, 0 or 1, like the Python 2 C{cmp} builtin. This is kept
        for compatibility; L{pickServer} sorts with an equivalent key.
        """
        keyA = (a[0], a[1])
        keyB = (b[0], b[1])
        return (keyA > keyB) - (keyA < keyB)

    def pickServer(self):
        """
        Pick the next server to try, following the selection rules of RFC 2782.

        Candidates come from the lowest-priority group still left in this
        round. Within that group, a server is chosen at random with
        probability based on its weight. A zero-weight server is chosen only
        when the random draw is 0. The chosen server moves to
        C{self.orderedServers}. When C{self.servers} is exhausted, those
        servers become the next round.

        @return: a C{(host, port)} tuple. When there are no usable SRV
            records, C{(self.domain, self.service)} is returned so the
            connection falls back to the bare domain and the service name.
        """
        assert self.servers is not None
        assert self.orderedServers is not None

        if not self.servers and not self.orderedServers:
            # no SRV record, fall back..
            return self.domain, self.service

        if not self.servers and self.orderedServers:
            # start new round
            self.servers = self.orderedServers
            self.orderedServers = []

        assert self.servers

        # Order by (priority, weight). The lowest-priority group becomes a
        # prefix of the list, with zero-weight entries first inside it, as
        # RFC 2782 requires.
        self.servers.sort(key=lambda server: (server[0], server[1]))
        minPriority = self.servers[0][0]

        # (index, running weight sum) for each server in the lowest-priority
        # group. Because that group is a prefix, the index is also valid in
        # self.servers.
        weightSum = 0
        weightIndex = []
        for index, server in enumerate(self.servers):
            if server[0] != minPriority:
                break
            weightSum += server[1]
            weightIndex.append((index, weightSum))

        # Choose the first server whose running sum reaches the draw. The
        # last running sum equals weightSum, so a choice is always made.
        rand = random.randint(0, weightSum)
        for index, runningSum in weightIndex:
            if runningSum >= rand:
                chosen = self.servers[index]
                del self.servers[index]
                self.orderedServers.append(chosen)

                _priority, _weight, host, port = chosen
                return host, port

        raise RuntimeError('Impossible %s pickServer result.'
                           % (self.__class__.__name__,))

    def _reallyConnect(self):
        """
        Pick a server and connect to it through the configured reactor
        method. Does nothing if stopConnecting() was called first.
        """
        if self.stopAfterDNS:
            # NOTE(review): this path never calls connectionFailed(), so the
            # factory receives no clientConnectionFailed()/doStop() to
            # balance the doStart() made in connect(). Confirm this is
            # intended.
            self.stopAfterDNS=0
            return

        self.host, self.port = self.pickServer()
        assert self.host is not None, 'Must have a host to connect to.'
        assert self.port is not None, 'Must have a port to connect to.'

        connectFunc = getattr(self.reactor, self.connectFuncName)
        self.connector=connectFunc(
            self.host, self.port,
            _SRVConnector_ClientFactoryWrapper(self, self.factory),
            *self.connectFuncArgs, **self.connectFuncKwArgs)

    def stopConnecting(self):
        """Stop attempting to connect.

        If no underlying connector exists yet (for example, the SRV lookup
        is still pending), the next server pick is skipped instead.
        """
        if self.connector:
            self.connector.stopConnecting()
        else:
            self.stopAfterDNS=1

    def disconnect(self):
        """Disconnect, whatever our state is."""
        if self.connector is not None:
            self.connector.disconnect()
        else:
            self.stopConnecting()

    def getDestination(self):
        """
        Return the destination of the underlying connector.

        Only valid once a server has been picked and a connection started.
        """
        assert self.connector
        return self.connector.getDestination()

    def connectionFailed(self, reason):
        """Report a failed connection attempt to the factory and stop it."""
        self.factory.clientConnectionFailed(self, reason)
        self.factory.doStop()

    def connectionLost(self, reason):
        """Report a lost connection to the factory and stop it."""
        self.factory.clientConnectionLost(self, reason)
        self.factory.doStop()
Note: See TracBrowser for help on using the browser.