Ticket #990: defer-cancel.patch

File defer-cancel.patch, 10.8 KB (added by jknight, 12 years ago)
  • twisted/internet/defer.py

     
    2525class AlreadyArmedError(Exception):
    2626    pass
    2727
    28 class TimeoutError(Exception):
     28class CancelledError(Exception):
    2929    pass
     30# Backwards compatibility
     31TimeoutError = CancelledError
    3032
    3133def logError(err):
    3234    log.err(err)
     
    139141    return deferred
    140142
    141143def timeout(deferred):
    142     deferred.errback(failure.Failure(TimeoutError("Callback timed out")))
     144    deferred.cancel()
    143145
    144146def passthru(arg):
    145147    return arg
     
    166168
    167169    For more information about Deferreds, see doc/howto/defer.html or
    168170    U{http://www.twistedmatrix.com/documents/howto/defer}
     171
     172    When creating a Deferred, you may provide a canceller function,
     173    which will be called by d.cancel() to let you do any cleanup necessary
     174    if the user decides not to wait for the deferred to complete.
    169175    """
    170176    called = 0
    171177    default = 0
    172178    paused = 0
    173179    timeoutCall = None
    174180    _debugInfo = None
    175 
     181    suppressAlreadyCalled = 0
     182   
    176183    # Keep this class attribute for now, for compatibility with code that
    177184    # sets it directly.
    178185    debug = False
    179186
    180     def __init__(self):
     187    def __init__(self, canceller=None):
    181188        self.callbacks = []
     189        self.canceller = canceller
    182190        if self.debug:
    183191            self._debugInfo = DebugInfo()
    184192            self._debugInfo.creator = traceback.format_stack()[:-1]
     
    278286
    279287    def pause(self):
    280288        """Stop processing on a Deferred until L{unpause}() is called.
     289        You probably don't ever have a reason to call this function.
    281290        """
    282291        self.paused = self.paused + 1
    283292
    284293
    285294    def unpause(self):
    286295        """Process all callbacks made since L{pause}() was called.
     296        You probably don't ever have a reason to call this function.
    287297        """
    288298        self.paused = self.paused - 1
    289299        if self.paused:
     
    291301        if self.called:
    292302            self._runCallbacks()
    293303
     304    def cancel(self):
     305        """Cancel this deferred.
     306
     307        If the deferred is waiting on another deferred, forward the
     308        cancellation to the other deferred.
     309
     310        If the deferred has not yet been errback'd/callback'd, call
     311        the canceller function provided to the constructor. If that
     312        function does not do a callback/errback, or if no canceller
     313        function was provided, errback with CancelledError.
     314
     315        Otherwise, raise AlreadyCalledError.
     316        """
     317        canceller=self.canceller
     318        if not self.called:
     319            if canceller:
     320                canceller(self)
     321            else:
     322                # Eat the callback that will eventually be fired
     323                # since there was no real canceller.
     324                self.suppressAlreadyCalled = 1
     325
     326            if not self.called:
     327                # The canceller didn't do an errback of its own
     328                try:
     329                    raise CancelledError
     330                except:
     331                    self.errback(failure.Failure())
     332        elif isinstance(self.result, Deferred):
     333            # Waiting for another deferred -- cancel it instead
     334            self.result.cancel()
     335        else:
     336            # Called and not waiting for another deferred
     337            raise AlreadyCalledError
     338   
    294339    def _continue(self, result):
    295340        self.result = result
    296341        self.unpause()
    297342
    298343    def _startRunCallbacks(self, result):
     344        # Canceller is no longer relevant
     345        self.canceller=None
     346       
    299347        if self.called:
     348            if self.suppressAlreadyCalled:
     349                self.suppressAlreadyCalled = False
     350                return
    300351            if self.debug:
    301352                extra = "\n" + self._debugInfo._getDebugTracebacks()
    302353                raise AlreadyCalledError(extra)
     
    363414
    364415        @param timeoutFunc: will receive the Deferred and *args, **kw as its
    365416        arguments.  The default timeoutFunc will call the errback with a
    366         L{TimeoutError}.
     417        L{CancelledError}.
     418
    367419        """
    368420        warnings.warn(
    369421            "Deferred.setTimeout is deprecated.  Look for timeout "
     
    699751
    700752    locked = 0
    701753
     754    def _cancelAcquire(self, d):
     755        self.waiting.remove(d)
     756       
    702757    def acquire(self):
    703758        """Attempt to acquire the lock.
    704759
    705760        @return: a Deferred which fires on lock acquisition.
    706761        """
    707         d = Deferred()
     762        d = Deferred(canceller=self._cancelAcquire)
    708763        if self.locked:
    709764            self.waiting.append(d)
    710765        else:
     
    736791        _ConcurrencyPrimitive.__init__(self)
    737792        self.tokens = tokens
    738793        self.limit = tokens
    739 
     794       
     795    def _cancelAcquire(self, d):
     796        self.waiting.remove(d)
     797       
    740798    def acquire(self):
    741799        """Attempt to acquire the token.
    742800
    743801        @return: a Deferred which fires on token acquisition.
    744802        """
    745803        assert self.tokens >= 0, "Internal inconsistency??  tokens should never be negative"
    746         d = Deferred()
     804        d = Deferred(canceller=self._cancelAcquire)
    747805        if not self.tokens:
    748806            self.waiting.append(d)
    749807        else:
     
    797855        self.size = size
    798856        self.backlog = backlog
    799857
     858    def _cancelGet(self, d):
     859        self.waiting.remove(d)
     860   
    800861    def put(self, obj):
    801862        """Add an object to this queue.
    802863
     
    820881        if self.pending:
    821882            return succeed(self.pending.pop(0))
    822883        elif self.backlog is None or len(self.waiting) < self.backlog:
    823             d = Deferred()
     884            d = Deferred(canceller=self._cancelGet)
    824885            self.waiting.append(d)
    825886            return d
    826887        else:
     
    828889
    829890
    830891__all__ = ["Deferred", "DeferredList", "succeed", "fail", "FAILURE", "SUCCESS",
    831            "AlreadyCalledError", "TimeoutError", "gatherResults",
     892           "AlreadyCalledError", "TimeoutError", "CancelledError", "gatherResults",
    832893           "maybeDeferred", "waitForDeferred", "deferredGenerator",
    833894           "DeferredLock", "DeferredSemaphore", "DeferredQueue",
    834895          ]
  • twisted/test/test_defer.py

     
    449449        else:
    450450            self.fail("second callback failed to raise AlreadyCalledError")
    451451
     452class FooError(Exception):
     453    pass
    452454
     455class DeferredCancellerTest(unittest.TestCase):
     456    def setUp(self):
     457        self.callback_results = None
     458        self.errback_results = None
     459        self.callback2_results = None
     460        self.cancellerCalled = False
     461       
     462    def _callback(self, data):
     463        self.callback_results = data
     464        return args[0]
     465
     466    def _callback2(self, data):
     467        self.callback2_results = data
     468
     469    def _errback(self, data):
     470        self.errback_results = data
     471
     472   
     473    def testNoCanceller(self):
     474        # Deferred without a canceller errbacks defer.CancelledError
     475        d=defer.Deferred()
     476        d.addCallbacks(self._callback, self._errback)
     477        d.cancel()
     478        self.assertEquals(self.errback_results.type, defer.CancelledError)
     479
     480        # Test that further callbacks *are* swallowed
     481        d.callback(None)
     482
     483        # But that a second is not
     484        self.assertRaises(defer.AlreadyCalledError, d.callback, None)
     485       
     486    def testCanceller(self):
     487        def cancel(d):
     488            self.cancellerCalled=True
     489           
     490        d=defer.Deferred(canceller=cancel)
     491        d.addCallbacks(self._callback, self._errback)
     492        d.cancel()
     493        self.assertEquals(self.cancellerCalled, True)
     494        self.assertEquals(self.errback_results.type, defer.CancelledError)
     495
     496        # Test that further callbacks are *not* swallowed
     497        self.assertRaises(defer.AlreadyCalledError, d.callback, None)
     498       
     499    def testCancellerWithCallback(self):
     500        # If we explicitly callback from the canceller, don't callback CancelledError
     501        def cancel(d):
     502            self.cancellerCalled=True
     503            d.errback(FooError())
     504        d=defer.Deferred(canceller=cancel)
     505        d.addCallbacks(self._callback, self._errback)
     506        d.cancel()
     507        self.assertEquals(self.cancellerCalled, True)
     508        self.assertEquals(self.errback_results.type, FooError)
     509
     510    def testCancelAlreadyCalled(self):
     511        def cancel(d):
     512            self.cancellerCalled=True
     513        d=defer.Deferred(canceller=cancel)
     514        d.callback(None)
     515        self.assertRaises(defer.AlreadyCalledError, d.cancel)
     516        self.assertEquals(self.cancellerCalled, False)
     517   
     518    def testCancelNestedDeferred(self):
     519        def innerCancel(d):
     520            self.assertIdentical(d, innerDeferred)
     521            self.cancellerCalled=True
     522        def cancel(d):
     523            self.assert_(False)
     524           
     525        innerDeferred=defer.Deferred(canceller=innerCancel)
     526        d=defer.Deferred(canceller=cancel)
     527        d.callback(None)
     528        d.addCallback(lambda data: innerDeferred)
     529        d.cancel()
     530        d.addCallbacks(self._callback, self._errback)
     531        self.assertEquals(self.cancellerCalled, True)
     532        self.assertEquals(self.errback_results.type, defer.CancelledError)
     533
    453534class LogTestCase(unittest.TestCase):
    454535
    455536    def setUp(self):
     
    539620        self.failUnless(lock.locked)
    540621        self.assertEquals(self.counter, 3)
    541622
     623        d = lock.acquire().addBoth(lambda x: setattr(self, 'result', x))
     624        d.cancel()
     625        self.assertEquals(self.result.type, defer.CancelledError)
     626       
    542627        lock.release()
    543628        self.failIf(lock.locked)
    544629
     
    567652            sem.acquire().addCallback(self._incr)
    568653            self.assertEquals(self.counter, i)
    569654
     655
     656        success = []
     657        def fail(r):
     658            success.append(False)
     659        def succeed(r):
     660            success.append(True)
     661        d = sem.acquire().addCallbacks(fail, succeed)
     662        d.cancel()
     663        self.assertEquals(success, [True])
     664       
    570665        sem.acquire().addCallback(self._incr)
    571666        self.assertEquals(self.counter, N)
    572667
    573668        sem.release()
    574669        self.assertEquals(self.counter, N + 1)
    575 
     670       
    576671        for i in range(1, 1 + N):
    577672            sem.release()
    578673            self.assertEquals(self.counter, N + 1)
     
    614709        queue = defer.DeferredQueue(backlog=0)
    615710        self.assertRaises(defer.QueueUnderflow, queue.get)
    616711
     712        queue = defer.DeferredQueue(size=0)
     713
     714        success = []
     715        def fail(r):
     716            success.append(False)
     717        def succeed(r):
     718            success.append(True)
     719        d = queue.get().addCallbacks(fail, succeed)
     720        d.cancel()
     721        self.assertEquals(success, [True])
     722        self.assertRaises(defer.QueueOverflow, queue.put, None)