| 1 |
|
|---|
| 2 |
|
|---|
| 3 |
|
|---|
| 4 |
|
|---|
| 5 |
import random |
|---|
| 6 |
|
|---|
| 7 |
from zope.interface import implements |
|---|
| 8 |
|
|---|
| 9 |
from twisted.internet import error, interfaces |
|---|
| 10 |
|
|---|
| 11 |
from twisted.names import client, dns |
|---|
| 12 |
from twisted.names.error import DNSNameError |
|---|
| 13 |
|
|---|
| 14 |
class _SRVConnector_ClientFactoryWrapper: |
|---|
| 15 |
def __init__(self, connector, wrappedFactory): |
|---|
| 16 |
self.__connector = connector |
|---|
| 17 |
self.__wrappedFactory = wrappedFactory |
|---|
| 18 |
|
|---|
| 19 |
def startedConnecting(self, connector): |
|---|
| 20 |
self.__wrappedFactory.startedConnecting(self.__connector) |
|---|
| 21 |
|
|---|
| 22 |
def clientConnectionFailed(self, connector, reason): |
|---|
| 23 |
self.__connector.connectionFailed(reason) |
|---|
| 24 |
|
|---|
| 25 |
def clientConnectionLost(self, connector, reason): |
|---|
| 26 |
self.__connector.connectionLost(reason) |
|---|
| 27 |
|
|---|
| 28 |
def __getattr__(self, key): |
|---|
| 29 |
return getattr(self.__wrappedFactory, key) |
|---|
| 30 |
|
|---|
| 31 |
class SRVConnector:
    """A connector that looks up DNS SRV records. See RFC2782."""

    implements(interfaces.IConnector)

    # Set by stopConnecting() while no reactor-level connector exists yet,
    # so that _reallyConnect() abandons the attempt once the DNS lookup ends.
    stopAfterDNS = 0

    def __init__(self, reactor, service, domain, factory,
                 protocol='tcp', connectFuncName='connectTCP',
                 connectFuncArgs=(),
                 connectFuncKwArgs=None,
                 ):
        """
        @param reactor: object whose C{connectFuncName} method makes the
            actual connection.
        @param service: service label used in the SRV query
            C{_service._protocol.domain}; also used as the port when the
            lookup yields no usable records.
        @param domain: domain to query; C{None} makes connect() fail with
            C{DNSLookupError}.
        @param factory: client factory notified of connection events.
        @param protocol: protocol label for the SRV query (default C{'tcp'}).
        @param connectFuncName: name of the reactor method to call
            (default C{'connectTCP'}).
        @param connectFuncArgs: extra positional arguments for that method.
        @param connectFuncKwArgs: extra keyword arguments for that method;
            C{None} (the default) means no extra keyword arguments.
        """
        self.reactor = reactor
        self.service = service
        self.domain = domain
        self.factory = factory

        self.protocol = protocol
        self.connectFuncName = connectFuncName
        self.connectFuncArgs = connectFuncArgs
        if connectFuncKwArgs is None:
            # Fresh dict per instance rather than a shared mutable default.
            connectFuncKwArgs = {}
        self.connectFuncKwArgs = connectFuncKwArgs

        # Reactor-level connector for the server currently being tried.
        self.connector = None
        # Both lists hold (priority, weight, host, port) tuples.
        # servers: records still to try in the current round.
        # orderedServers: records already tried in this round. Right after a
        # lookup every record starts here and moves to servers on the first
        # pickServer() call.
        self.servers = None
        self.orderedServers = None

    def connect(self):
        """Start connection to remote server."""
        self.factory.doStart()
        self.factory.startedConnecting(self)

        if not self.servers:
            # No lookup yet, or every server of the last lookup was tried:
            # query DNS (again).
            if self.domain is None:
                self.connectionFailed(error.DNSLookupError("Domain is not defined."))
                return
            d = client.lookupService('_%s._%s.%s' % (self.service,
                                                     self.protocol,
                                                     self.domain))
            d.addCallbacks(self._cbGotServers, self._ebGotServers)
            d.addCallback(lambda ignored: self._reallyConnect())
            # Lookup errors not handled by _ebGotServers, a DNSLookupError
            # from _cbGotServers, and errors from _reallyConnect all end here.
            d.addErrback(self.connectionFailed)
        elif self.connector is None:
            self._reallyConnect()
        else:
            self.connector.connect()

    def _ebGotServers(self, failure):
        """
        Handle a failed SRV lookup.

        A C{DNSNameError} is treated like an empty answer, which makes
        pickServer() fall back to the bare domain. Any other failure is
        re-raised by C{trap} and propagates to connectionFailed().
        """
        failure.trap(DNSNameError)

        # NOTE(review): presumably because some DNS servers answer NXDOMAIN
        # when a domain merely has no SRV records; treat it as "no records"
        # and use the fallback.
        self.servers = []
        self.orderedServers = []

    def _cbGotServers(self, result):
        """
        Store the SRV records from a successful lookup.

        @param result: C{(answers, authority, additional)} tuple delivered by
            the lookup Deferred.
        @raise error.DNSLookupError: if the only answer is an SRV record
            whose target is C{.}, meaning the service is not available.
        """
        answers, auth, add = result
        if (len(answers) == 1 and answers[0].type == dns.SRV
                and answers[0].payload
                and answers[0].payload.target == dns.Name('.')):
            raise error.DNSLookupError("Service %s not available for domain %s."
                                       % (repr(self.service), repr(self.domain)))

        self.servers = []
        self.orderedServers = []
        for a in answers:
            # Skip non-SRV answers and answers without a payload.
            if a.type != dns.SRV or not a.payload:
                continue
            self.orderedServers.append((a.payload.priority, a.payload.weight,
                                        str(a.payload.target), a.payload.port))

    def _serverCmp(self, a, b):
        """
        cmp-style comparison of two server tuples: by priority, then weight.

        Retained for backward compatibility; pickServer() sorts with the
        equivalent key instead.
        """
        keyA = (a[0], a[1])
        keyB = (b[0], b[1])
        return (keyA > keyB) - (keyA < keyB)

    def pickServer(self):
        """
        Pick the next server to try, using RFC 2782 weighted selection.

        Only the servers with the lowest priority value are candidates. Each
        is chosen with probability proportional to its weight; zero-weight
        servers are only likely to be chosen when all weights are zero. The
        chosen server moves from C{servers} to C{orderedServers}. When
        C{servers} is exhausted, the already-tried servers start a new round.

        @return: C{(host, port)}. If the lookup gave no usable records, this
            is C{(domain, service)}.
        """
        assert self.servers is not None
        assert self.orderedServers is not None

        if not self.servers and not self.orderedServers:
            # No SRV records at all: fall back to the bare domain, using the
            # service name as the port.
            return self.domain, self.service

        if not self.servers and self.orderedServers:
            # Every server has been picked once; start a new round.
            self.servers = self.orderedServers
            self.orderedServers = []

        assert self.servers

        # Lowest priority first; within one priority, lowest weight first.
        # This also puts zero-weight records at the front, as RFC 2782 asks.
        self.servers.sort(key=lambda server: (server[0], server[1]))
        minPriority = self.servers[0][0]

        # After sorting, the lowest-priority servers form a prefix of
        # self.servers. Pair each index with the running sum of weights.
        weightIndex = []
        weightSum = 0
        for index, server in enumerate(self.servers):
            if server[0] != minPriority:
                break
            weightSum += server[1]
            weightIndex.append((index, weightSum))

        # randint is inclusive at both ends. The first server whose running
        # sum reaches rand is selected (the original ignored rand and always
        # returned the last candidate).
        rand = random.randint(0, weightSum)
        for index, runningSum in weightIndex:
            if runningSum >= rand:
                chosen = self.servers.pop(index)
                self.orderedServers.append(chosen)

                priority, weight, host, port = chosen
                return host, port

        # Unreachable: the last running sum equals weightSum >= rand.
        raise RuntimeError('Impossible %s pickServer result.' % self.__class__.__name__)

    def _reallyConnect(self):
        """
        Pick a server and open a reactor-level connection to it.

        Does nothing (and clears the flag) if stopConnecting() was called
        before a reactor-level connector existed.
        """
        if self.stopAfterDNS:
            self.stopAfterDNS = 0
            return

        self.host, self.port = self.pickServer()
        assert self.host is not None, 'Must have a host to connect to.'
        assert self.port is not None, 'Must have a port to connect to.'

        connectFunc = getattr(self.reactor, self.connectFuncName)
        # The wrapper routes failure/loss events back through this object.
        self.connector = connectFunc(
            self.host, self.port,
            _SRVConnector_ClientFactoryWrapper(self, self.factory),
            *self.connectFuncArgs, **self.connectFuncKwArgs)

    def stopConnecting(self):
        """Stop attempting to connect."""
        if self.connector:
            self.connector.stopConnecting()
        else:
            # Still resolving: abort once the DNS lookup completes.
            self.stopAfterDNS = 1

    def disconnect(self):
        """Disconnect whatever our state is."""
        if self.connector is not None:
            self.connector.disconnect()
        else:
            self.stopConnecting()

    def getDestination(self):
        """Return the destination of the current reactor-level connector."""
        assert self.connector
        return self.connector.getDestination()

    def connectionFailed(self, reason):
        """Report a failed connection attempt to the factory, then stop it."""
        self.factory.clientConnectionFailed(self, reason)
        self.factory.doStop()

    def connectionLost(self, reason):
        """Report a lost connection to the factory, then stop it."""
        self.factory.clientConnectionLost(self, reason)
        self.factory.doStop()
| 185 |
|
|---|