| 1 |
|
|---|
| 2 |
|
|---|
| 3 |
|
|---|
| 4 |
|
|---|
| 5 |
""" |
|---|
| 6 |
Testing support for protocols -- loopback between client and server. |
|---|
| 7 |
""" |
|---|
| 8 |
|
|---|
| 9 |
|
|---|
| 10 |
import tempfile |
|---|
| 11 |
from zope.interface import implements |
|---|
| 12 |
|
|---|
| 13 |
|
|---|
| 14 |
from twisted.protocols import policies |
|---|
| 15 |
from twisted.internet import interfaces, protocol, main, defer |
|---|
| 16 |
from twisted.internet.task import deferLater |
|---|
| 17 |
from twisted.python import failure |
|---|
| 18 |
from twisted.internet.interfaces import IAddress |
|---|
| 19 |
|
|---|
| 20 |
|
|---|
| 21 |
class _LoopbackQueue(object): |
|---|
| 22 |
""" |
|---|
| 23 |
Trivial wrapper around a list to give it an interface like a queue, which |
|---|
| 24 |
the addition of also sending notifications by way of a Deferred whenever |
|---|
| 25 |
the list has something added to it. |
|---|
| 26 |
""" |
|---|
| 27 |
|
|---|
| 28 |
_notificationDeferred = None |
|---|
| 29 |
disconnect = False |
|---|
| 30 |
|
|---|
| 31 |
def __init__(self): |
|---|
| 32 |
self._queue = [] |
|---|
| 33 |
|
|---|
| 34 |
|
|---|
| 35 |
def put(self, v): |
|---|
| 36 |
self._queue.append(v) |
|---|
| 37 |
if self._notificationDeferred is not None: |
|---|
| 38 |
d, self._notificationDeferred = self._notificationDeferred, None |
|---|
| 39 |
d.callback(None) |
|---|
| 40 |
|
|---|
| 41 |
|
|---|
| 42 |
def __nonzero__(self): |
|---|
| 43 |
return bool(self._queue) |
|---|
| 44 |
|
|---|
| 45 |
|
|---|
| 46 |
def get(self): |
|---|
| 47 |
return self._queue.pop(0) |
|---|
| 48 |
|
|---|
| 49 |
|
|---|
| 50 |
|
|---|
| 51 |
class _LoopbackAddress(object): |
|---|
| 52 |
implements(IAddress) |
|---|
| 53 |
|
|---|
| 54 |
|
|---|
| 55 |
class _LoopbackTransport(object): |
|---|
| 56 |
implements(interfaces.ITransport, interfaces.IConsumer) |
|---|
| 57 |
|
|---|
| 58 |
disconnecting = False |
|---|
| 59 |
producer = None |
|---|
| 60 |
|
|---|
| 61 |
|
|---|
| 62 |
def __init__(self, q): |
|---|
| 63 |
self.q = q |
|---|
| 64 |
|
|---|
| 65 |
def write(self, bytes): |
|---|
| 66 |
self.q.put(bytes) |
|---|
| 67 |
|
|---|
| 68 |
def writeSequence(self, iovec): |
|---|
| 69 |
self.q.put(''.join(iovec)) |
|---|
| 70 |
|
|---|
| 71 |
def loseConnection(self): |
|---|
| 72 |
self.q.disconnect = True |
|---|
| 73 |
self.q.put(None) |
|---|
| 74 |
|
|---|
| 75 |
def getPeer(self): |
|---|
| 76 |
return _LoopbackAddress() |
|---|
| 77 |
|
|---|
| 78 |
def getHost(self): |
|---|
| 79 |
return _LoopbackAddress() |
|---|
| 80 |
|
|---|
| 81 |
|
|---|
| 82 |
def registerProducer(self, producer, streaming): |
|---|
| 83 |
assert self.producer is None |
|---|
| 84 |
self.producer = producer |
|---|
| 85 |
self.streamingProducer = streaming |
|---|
| 86 |
self._pollProducer() |
|---|
| 87 |
|
|---|
| 88 |
def unregisterProducer(self): |
|---|
| 89 |
assert self.producer is not None |
|---|
| 90 |
self.producer = None |
|---|
| 91 |
|
|---|
| 92 |
def _pollProducer(self): |
|---|
| 93 |
if self.producer is not None and not self.streamingProducer: |
|---|
| 94 |
self.producer.resumeProducing() |
|---|
| 95 |
|
|---|
| 96 |
|
|---|
| 97 |
|
|---|
| 98 |
def identityPumpPolicy(queue, target): |
|---|
| 99 |
""" |
|---|
| 100 |
L{identityPumpPolicy} is a policy which delivers each chunk of data written |
|---|
| 101 |
to the given queue as-is to the target. |
|---|
| 102 |
|
|---|
| 103 |
This isn't a particularly realistic policy. |
|---|
| 104 |
|
|---|
| 105 |
@see: L{loopbackAsync} |
|---|
| 106 |
""" |
|---|
| 107 |
while queue: |
|---|
| 108 |
bytes = queue.get() |
|---|
| 109 |
if bytes is None: |
|---|
| 110 |
break |
|---|
| 111 |
target.dataReceived(bytes) |
|---|
| 112 |
|
|---|
| 113 |
|
|---|
| 114 |
|
|---|
| 115 |
def collapsingPumpPolicy(queue, target): |
|---|
| 116 |
""" |
|---|
| 117 |
L{collapsingPumpPolicy} is a policy which collapses all outstanding chunks |
|---|
| 118 |
into a single string and delivers it to the target. |
|---|
| 119 |
|
|---|
| 120 |
@see: L{loopbackAsync} |
|---|
| 121 |
""" |
|---|
| 122 |
bytes = [] |
|---|
| 123 |
while queue: |
|---|
| 124 |
chunk = queue.get() |
|---|
| 125 |
if chunk is None: |
|---|
| 126 |
break |
|---|
| 127 |
bytes.append(chunk) |
|---|
| 128 |
if bytes: |
|---|
| 129 |
target.dataReceived(''.join(bytes)) |
|---|
| 130 |
|
|---|
| 131 |
|
|---|
| 132 |
|
|---|
| 133 |
def loopbackAsync(server, client, pumpPolicy=identityPumpPolicy): |
|---|
| 134 |
""" |
|---|
| 135 |
Establish a connection between C{server} and C{client} then transfer data |
|---|
| 136 |
between them until the connection is closed. This is often useful for |
|---|
| 137 |
testing a protocol. |
|---|
| 138 |
|
|---|
| 139 |
@param server: The protocol instance representing the server-side of this |
|---|
| 140 |
connection. |
|---|
| 141 |
|
|---|
| 142 |
@param client: The protocol instance representing the client-side of this |
|---|
| 143 |
connection. |
|---|
| 144 |
|
|---|
| 145 |
@param pumpPolicy: When either C{server} or C{client} writes to its |
|---|
| 146 |
transport, the string passed in is added to a queue of data for the |
|---|
| 147 |
other protocol. Eventually, C{pumpPolicy} will be called with one such |
|---|
| 148 |
queue and the corresponding protocol object. The pump policy callable |
|---|
| 149 |
is responsible for emptying the queue and passing the strings it |
|---|
| 150 |
contains to the given protocol's C{dataReceived} method. The signature |
|---|
| 151 |
of C{pumpPolicy} is C{(queue, protocol)}. C{queue} is an object with a |
|---|
| 152 |
C{get} method which will return the next string written to the |
|---|
| 153 |
transport, or C{None} if the transport has been disconnected, and which |
|---|
| 154 |
evaluates to C{True} if and only if there are more items to be |
|---|
| 155 |
retrieved via C{get}. |
|---|
| 156 |
|
|---|
| 157 |
@return: A L{Deferred} which fires when the connection has been closed and |
|---|
| 158 |
both sides have received notification of this. |
|---|
| 159 |
""" |
|---|
| 160 |
serverToClient = _LoopbackQueue() |
|---|
| 161 |
clientToServer = _LoopbackQueue() |
|---|
| 162 |
|
|---|
| 163 |
server.makeConnection(_LoopbackTransport(serverToClient)) |
|---|
| 164 |
client.makeConnection(_LoopbackTransport(clientToServer)) |
|---|
| 165 |
|
|---|
| 166 |
return _loopbackAsyncBody( |
|---|
| 167 |
server, serverToClient, client, clientToServer, pumpPolicy) |
|---|
| 168 |
|
|---|
| 169 |
|
|---|
| 170 |
|
|---|
| 171 |
def _loopbackAsyncBody(server, serverToClient, client, clientToServer, |
|---|
| 172 |
pumpPolicy): |
|---|
| 173 |
""" |
|---|
| 174 |
Transfer bytes from the output queue of each protocol to the input of the other. |
|---|
| 175 |
|
|---|
| 176 |
@param server: The protocol instance representing the server-side of this |
|---|
| 177 |
connection. |
|---|
| 178 |
|
|---|
| 179 |
@param serverToClient: The L{_LoopbackQueue} holding the server's output. |
|---|
| 180 |
|
|---|
| 181 |
@param client: The protocol instance representing the client-side of this |
|---|
| 182 |
connection. |
|---|
| 183 |
|
|---|
| 184 |
@param clientToServer: The L{_LoopbackQueue} holding the client's output. |
|---|
| 185 |
|
|---|
| 186 |
@param pumpPolicy: See L{loopbackAsync}. |
|---|
| 187 |
|
|---|
| 188 |
@return: A L{Deferred} which fires when the connection has been closed and |
|---|
| 189 |
both sides have received notification of this. |
|---|
| 190 |
""" |
|---|
| 191 |
def pump(source, q, target): |
|---|
| 192 |
sent = False |
|---|
| 193 |
if q: |
|---|
| 194 |
pumpPolicy(q, target) |
|---|
| 195 |
sent = True |
|---|
| 196 |
if sent and not q: |
|---|
| 197 |
|
|---|
| 198 |
|
|---|
| 199 |
source.transport._pollProducer() |
|---|
| 200 |
|
|---|
| 201 |
return sent |
|---|
| 202 |
|
|---|
| 203 |
while 1: |
|---|
| 204 |
disconnect = clientSent = serverSent = False |
|---|
| 205 |
|
|---|
| 206 |
|
|---|
| 207 |
serverSent = pump(server, serverToClient, client) |
|---|
| 208 |
clientSent = pump(client, clientToServer, server) |
|---|
| 209 |
|
|---|
| 210 |
if not clientSent and not serverSent: |
|---|
| 211 |
|
|---|
| 212 |
|
|---|
| 213 |
d = defer.Deferred() |
|---|
| 214 |
clientToServer._notificationDeferred = d |
|---|
| 215 |
serverToClient._notificationDeferred = d |
|---|
| 216 |
d.addCallback( |
|---|
| 217 |
_loopbackAsyncContinue, |
|---|
| 218 |
server, serverToClient, client, clientToServer, pumpPolicy) |
|---|
| 219 |
return d |
|---|
| 220 |
if serverToClient.disconnect: |
|---|
| 221 |
|
|---|
| 222 |
|
|---|
| 223 |
disconnect = True |
|---|
| 224 |
pump(server, serverToClient, client) |
|---|
| 225 |
elif clientToServer.disconnect: |
|---|
| 226 |
|
|---|
| 227 |
|
|---|
| 228 |
disconnect = True |
|---|
| 229 |
pump(client, clientToServer, server) |
|---|
| 230 |
if disconnect: |
|---|
| 231 |
|
|---|
| 232 |
server.connectionLost(failure.Failure(main.CONNECTION_DONE)) |
|---|
| 233 |
client.connectionLost(failure.Failure(main.CONNECTION_DONE)) |
|---|
| 234 |
return defer.succeed(None) |
|---|
| 235 |
|
|---|
| 236 |
|
|---|
| 237 |
|
|---|
| 238 |
def _loopbackAsyncContinue(ignored, server, serverToClient, client, |
|---|
| 239 |
clientToServer, pumpPolicy): |
|---|
| 240 |
|
|---|
| 241 |
|
|---|
| 242 |
clientToServer._notificationDeferred = None |
|---|
| 243 |
serverToClient._notificationDeferred = None |
|---|
| 244 |
|
|---|
| 245 |
|
|---|
| 246 |
|
|---|
| 247 |
|
|---|
| 248 |
|
|---|
| 249 |
from twisted.internet import reactor |
|---|
| 250 |
return deferLater( |
|---|
| 251 |
reactor, 0, |
|---|
| 252 |
_loopbackAsyncBody, |
|---|
| 253 |
server, serverToClient, client, clientToServer, pumpPolicy) |
|---|
| 254 |
|
|---|
| 255 |
|
|---|
| 256 |
|
|---|
| 257 |
class LoopbackRelay: |
|---|
| 258 |
|
|---|
| 259 |
implements(interfaces.ITransport, interfaces.IConsumer) |
|---|
| 260 |
|
|---|
| 261 |
buffer = '' |
|---|
| 262 |
shouldLose = 0 |
|---|
| 263 |
disconnecting = 0 |
|---|
| 264 |
producer = None |
|---|
| 265 |
|
|---|
| 266 |
def __init__(self, target, logFile=None): |
|---|
| 267 |
self.target = target |
|---|
| 268 |
self.logFile = logFile |
|---|
| 269 |
|
|---|
| 270 |
def write(self, data): |
|---|
| 271 |
self.buffer = self.buffer + data |
|---|
| 272 |
if self.logFile: |
|---|
| 273 |
self.logFile.write("loopback writing %s\n" % repr(data)) |
|---|
| 274 |
|
|---|
| 275 |
def writeSequence(self, iovec): |
|---|
| 276 |
self.write("".join(iovec)) |
|---|
| 277 |
|
|---|
| 278 |
def clearBuffer(self): |
|---|
| 279 |
if self.shouldLose == -1: |
|---|
| 280 |
return |
|---|
| 281 |
|
|---|
| 282 |
if self.producer: |
|---|
| 283 |
self.producer.resumeProducing() |
|---|
| 284 |
if self.buffer: |
|---|
| 285 |
if self.logFile: |
|---|
| 286 |
self.logFile.write("loopback receiving %s\n" % repr(self.buffer)) |
|---|
| 287 |
buffer = self.buffer |
|---|
| 288 |
self.buffer = '' |
|---|
| 289 |
self.target.dataReceived(buffer) |
|---|
| 290 |
if self.shouldLose == 1: |
|---|
| 291 |
self.shouldLose = -1 |
|---|
| 292 |
self.target.connectionLost(failure.Failure(main.CONNECTION_DONE)) |
|---|
| 293 |
|
|---|
| 294 |
def loseConnection(self): |
|---|
| 295 |
if self.shouldLose != -1: |
|---|
| 296 |
self.shouldLose = 1 |
|---|
| 297 |
|
|---|
| 298 |
def getHost(self): |
|---|
| 299 |
return 'loopback' |
|---|
| 300 |
|
|---|
| 301 |
def getPeer(self): |
|---|
| 302 |
return 'loopback' |
|---|
| 303 |
|
|---|
| 304 |
def registerProducer(self, producer, streaming): |
|---|
| 305 |
self.producer = producer |
|---|
| 306 |
|
|---|
| 307 |
def unregisterProducer(self): |
|---|
| 308 |
self.producer = None |
|---|
| 309 |
|
|---|
| 310 |
def logPrefix(self): |
|---|
| 311 |
return 'Loopback(%r)' % (self.target.__class__.__name__,) |
|---|
| 312 |
|
|---|
| 313 |
def loopback(server, client, logFile=None): |
|---|
| 314 |
"""Run session between server and client. |
|---|
| 315 |
DEPRECATED in Twisted 2.5. Use loopbackAsync instead. |
|---|
| 316 |
""" |
|---|
| 317 |
import warnings |
|---|
| 318 |
warnings.warn('loopback() is deprecated (since Twisted 2.5). ' |
|---|
| 319 |
'Use loopbackAsync() instead.', |
|---|
| 320 |
stacklevel=2, category=DeprecationWarning) |
|---|
| 321 |
from twisted.internet import reactor |
|---|
| 322 |
serverToClient = LoopbackRelay(client, logFile) |
|---|
| 323 |
clientToServer = LoopbackRelay(server, logFile) |
|---|
| 324 |
server.makeConnection(serverToClient) |
|---|
| 325 |
client.makeConnection(clientToServer) |
|---|
| 326 |
while 1: |
|---|
| 327 |
reactor.iterate(0.01) |
|---|
| 328 |
serverToClient.clearBuffer() |
|---|
| 329 |
clientToServer.clearBuffer() |
|---|
| 330 |
if serverToClient.shouldLose: |
|---|
| 331 |
serverToClient.clearBuffer() |
|---|
| 332 |
server.connectionLost(failure.Failure(main.CONNECTION_DONE)) |
|---|
| 333 |
break |
|---|
| 334 |
elif clientToServer.shouldLose: |
|---|
| 335 |
client.connectionLost(failure.Failure(main.CONNECTION_DONE)) |
|---|
| 336 |
break |
|---|
| 337 |
reactor.iterate() |
|---|
| 338 |
|
|---|
| 339 |
|
|---|
| 340 |
class LoopbackClientFactory(protocol.ClientFactory): |
|---|
| 341 |
|
|---|
| 342 |
def __init__(self, protocol): |
|---|
| 343 |
self.disconnected = 0 |
|---|
| 344 |
self.deferred = defer.Deferred() |
|---|
| 345 |
self.protocol = protocol |
|---|
| 346 |
|
|---|
| 347 |
def buildProtocol(self, addr): |
|---|
| 348 |
return self.protocol |
|---|
| 349 |
|
|---|
| 350 |
def clientConnectionLost(self, connector, reason): |
|---|
| 351 |
self.disconnected = 1 |
|---|
| 352 |
self.deferred.callback(None) |
|---|
| 353 |
|
|---|
| 354 |
|
|---|
| 355 |
class _FireOnClose(policies.ProtocolWrapper): |
|---|
| 356 |
def __init__(self, protocol, factory): |
|---|
| 357 |
policies.ProtocolWrapper.__init__(self, protocol, factory) |
|---|
| 358 |
self.deferred = defer.Deferred() |
|---|
| 359 |
|
|---|
| 360 |
def connectionLost(self, reason): |
|---|
| 361 |
policies.ProtocolWrapper.connectionLost(self, reason) |
|---|
| 362 |
self.deferred.callback(None) |
|---|
| 363 |
|
|---|
| 364 |
|
|---|
| 365 |
def loopbackTCP(server, client, port=0, noisy=True): |
|---|
| 366 |
"""Run session between server and client protocol instances over TCP.""" |
|---|
| 367 |
from twisted.internet import reactor |
|---|
| 368 |
f = policies.WrappingFactory(protocol.Factory()) |
|---|
| 369 |
serverWrapper = _FireOnClose(f, server) |
|---|
| 370 |
f.noisy = noisy |
|---|
| 371 |
f.buildProtocol = lambda addr: serverWrapper |
|---|
| 372 |
serverPort = reactor.listenTCP(port, f, interface='127.0.0.1') |
|---|
| 373 |
clientF = LoopbackClientFactory(client) |
|---|
| 374 |
clientF.noisy = noisy |
|---|
| 375 |
reactor.connectTCP('127.0.0.1', serverPort.getHost().port, clientF) |
|---|
| 376 |
d = clientF.deferred |
|---|
| 377 |
d.addCallback(lambda x: serverWrapper.deferred) |
|---|
| 378 |
d.addCallback(lambda x: serverPort.stopListening()) |
|---|
| 379 |
return d |
|---|
| 380 |
|
|---|
| 381 |
|
|---|
| 382 |
def loopbackUNIX(server, client, noisy=True): |
|---|
| 383 |
"""Run session between server and client protocol instances over UNIX socket.""" |
|---|
| 384 |
path = tempfile.mktemp() |
|---|
| 385 |
from twisted.internet import reactor |
|---|
| 386 |
f = policies.WrappingFactory(protocol.Factory()) |
|---|
| 387 |
serverWrapper = _FireOnClose(f, server) |
|---|
| 388 |
f.noisy = noisy |
|---|
| 389 |
f.buildProtocol = lambda addr: serverWrapper |
|---|
| 390 |
serverPort = reactor.listenUNIX(path, f) |
|---|
| 391 |
clientF = LoopbackClientFactory(client) |
|---|
| 392 |
clientF.noisy = noisy |
|---|
| 393 |
reactor.connectUNIX(path, clientF) |
|---|
| 394 |
d = clientF.deferred |
|---|
| 395 |
d.addCallback(lambda x: serverWrapper.deferred) |
|---|
| 396 |
d.addCallback(lambda x: serverPort.stopListening()) |
|---|
| 397 |
return d |
|---|