Ticket #990: defer-cancel.patch

File defer-cancel.patch, 10.8 KB (added by jknight, 9 years ago)
  • twisted/internet/defer.py

     
    2525class AlreadyArmedError(Exception): 
    2626    pass 
    2727 
    28 class TimeoutError(Exception): 
     28class CancelledError(Exception): 
    2929    pass 
     30# Backwards compatibility 
     31TimeoutError = CancelledError 
    3032 
    3133def logError(err): 
    3234    log.err(err) 
     
    139141    return deferred 
    140142 
    141143def timeout(deferred): 
    142     deferred.errback(failure.Failure(TimeoutError("Callback timed out"))) 
     144    deferred.cancel() 
    143145 
    144146def passthru(arg): 
    145147    return arg 
     
    166168 
    167169    For more information about Deferreds, see doc/howto/defer.html or 
    168170    U{http://www.twistedmatrix.com/documents/howto/defer} 
     171 
     172    When creating a Deferred, you may provide a canceller function, 
     173    which will be called by d.cancel() to let you do any cleanup necessary 
     174    if the user decides not to wait for the deferred to complete. 
    169175    """ 
    170176    called = 0 
    171177    default = 0 
    172178    paused = 0 
    173179    timeoutCall = None 
    174180    _debugInfo = None 
    175  
     181    suppressAlreadyCalled = 0 
     182     
    176183    # Keep this class attribute for now, for compatibility with code that 
    177184    # sets it directly. 
    178185    debug = False 
    179186 
    180     def __init__(self): 
     187    def __init__(self, canceller=None): 
    181188        self.callbacks = [] 
     189        self.canceller = canceller 
    182190        if self.debug: 
    183191            self._debugInfo = DebugInfo() 
    184192            self._debugInfo.creator = traceback.format_stack()[:-1] 
     
    278286 
    279287    def pause(self): 
    280288        """Stop processing on a Deferred until L{unpause}() is called. 
     289        You probably don't ever have a reason to call this function. 
    281290        """ 
    282291        self.paused = self.paused + 1 
    283292 
    284293 
    285294    def unpause(self): 
    286295        """Process all callbacks made since L{pause}() was called. 
     296        You probably don't ever have a reason to call this function. 
    287297        """ 
    288298        self.paused = self.paused - 1 
    289299        if self.paused: 
     
    291301        if self.called: 
    292302            self._runCallbacks() 
    293303 
     304    def cancel(self): 
     305        """Cancel this deferred. 
     306 
     307        If the deferred is waiting on another deferred, forward the 
     308        cancellation to the other deferred. 
     309 
     310        If the deferred has not yet been errback'd/callback'd, call 
     311        the canceller function provided to the constructor. If that 
     312        function does not do a callback/errback, or if no canceller 
     313        function was provided, errback with CancelledError. 
     314 
     315        Otherwise, raise AlreadyCalledError. 
     316        """ 
     317        canceller=self.canceller 
     318        if not self.called: 
     319            if canceller: 
     320                canceller(self) 
     321            else: 
     322                # Eat the callback that will eventually be fired 
     323                # since there was no real canceller. 
     324                self.suppressAlreadyCalled = 1 
     325 
     326            if not self.called: 
     327                # The canceller didn't do an errback of its own 
     328                try: 
     329                    raise CancelledError 
     330                except: 
     331                    self.errback(failure.Failure()) 
     332        elif isinstance(self.result, Deferred): 
     333            # Waiting for another deferred -- cancel it instead 
     334            self.result.cancel() 
     335        else: 
     336            # Called and not waiting for another deferred 
     337            raise AlreadyCalledError 
     338     
    294339    def _continue(self, result): 
    295340        self.result = result 
    296341        self.unpause() 
    297342 
    298343    def _startRunCallbacks(self, result): 
     344        # Canceller is no longer relevant 
     345        self.canceller=None 
     346         
    299347        if self.called: 
     348            if self.suppressAlreadyCalled: 
     349                self.suppressAlreadyCalled = False 
     350                return 
    300351            if self.debug: 
    301352                extra = "\n" + self._debugInfo._getDebugTracebacks() 
    302353                raise AlreadyCalledError(extra) 
     
    363414 
    364415        @param timeoutFunc: will receive the Deferred and *args, **kw as its 
    365416        arguments.  The default timeoutFunc will call the errback with a 
    366         L{TimeoutError}. 
     417        L{CancelledError}. 
     418 
    367419        """ 
    368420        warnings.warn( 
    369421            "Deferred.setTimeout is deprecated.  Look for timeout " 
     
    699751 
    700752    locked = 0 
    701753 
     754    def _cancelAcquire(self, d): 
     755        self.waiting.remove(d) 
     756         
    702757    def acquire(self): 
    703758        """Attempt to acquire the lock. 
    704759 
    705760        @return: a Deferred which fires on lock acquisition. 
    706761        """ 
    707         d = Deferred() 
     762        d = Deferred(canceller=self._cancelAcquire) 
    708763        if self.locked: 
    709764            self.waiting.append(d) 
    710765        else: 
     
    736791        _ConcurrencyPrimitive.__init__(self) 
    737792        self.tokens = tokens 
    738793        self.limit = tokens 
    739  
     794         
     795    def _cancelAcquire(self, d): 
     796        self.waiting.remove(d) 
     797         
    740798    def acquire(self): 
    741799        """Attempt to acquire the token. 
    742800 
    743801        @return: a Deferred which fires on token acquisition. 
    744802        """ 
    745803        assert self.tokens >= 0, "Internal inconsistency??  tokens should never be negative" 
    746         d = Deferred() 
     804        d = Deferred(canceller=self._cancelAcquire) 
    747805        if not self.tokens: 
    748806            self.waiting.append(d) 
    749807        else: 
     
    797855        self.size = size 
    798856        self.backlog = backlog 
    799857 
     858    def _cancelGet(self, d): 
     859        self.waiting.remove(d) 
     860     
    800861    def put(self, obj): 
    801862        """Add an object to this queue. 
    802863 
     
    820881        if self.pending: 
    821882            return succeed(self.pending.pop(0)) 
    822883        elif self.backlog is None or len(self.waiting) < self.backlog: 
    823             d = Deferred() 
     884            d = Deferred(canceller=self._cancelGet) 
    824885            self.waiting.append(d) 
    825886            return d 
    826887        else: 
     
    828889 
    829890 
    830891__all__ = ["Deferred", "DeferredList", "succeed", "fail", "FAILURE", "SUCCESS", 
    831            "AlreadyCalledError", "TimeoutError", "gatherResults", 
     892           "AlreadyCalledError", "TimeoutError", "CancelledError", "gatherResults", 
    832893           "maybeDeferred", "waitForDeferred", "deferredGenerator", 
    833894           "DeferredLock", "DeferredSemaphore", "DeferredQueue", 
    834895          ] 
  • twisted/test/test_defer.py

     
    449449        else: 
    450450            self.fail("second callback failed to raise AlreadyCalledError") 
    451451 
     452class FooError(Exception): 
     453    pass 
    452454 
     455class DeferredCancellerTest(unittest.TestCase): 
     456    def setUp(self): 
     457        self.callback_results = None 
     458        self.errback_results = None 
     459        self.callback2_results = None 
     460        self.cancellerCalled = False 
     461         
     462    def _callback(self, data): 
     463        self.callback_results = data 
     464        return args[0] 
     465 
     466    def _callback2(self, data): 
     467        self.callback2_results = data 
     468 
     469    def _errback(self, data): 
     470        self.errback_results = data 
     471 
     472     
     473    def testNoCanceller(self): 
     474        # Deferred without a canceller errbacks defer.CancelledError 
     475        d=defer.Deferred() 
     476        d.addCallbacks(self._callback, self._errback) 
     477        d.cancel() 
     478        self.assertEquals(self.errback_results.type, defer.CancelledError) 
     479 
     480        # Test that further callbacks *are* swallowed 
     481        d.callback(None) 
     482 
     483        # But that a second is not 
     484        self.assertRaises(defer.AlreadyCalledError, d.callback, None) 
     485         
     486    def testCanceller(self): 
     487        def cancel(d): 
     488            self.cancellerCalled=True 
     489             
     490        d=defer.Deferred(canceller=cancel) 
     491        d.addCallbacks(self._callback, self._errback) 
     492        d.cancel() 
     493        self.assertEquals(self.cancellerCalled, True) 
     494        self.assertEquals(self.errback_results.type, defer.CancelledError) 
     495 
     496        # Test that further callbacks are *not* swallowed 
     497        self.assertRaises(defer.AlreadyCalledError, d.callback, None) 
     498         
     499    def testCancellerWithCallback(self): 
     500        # If we explicitly callback from the canceller, don't callback CancelledError 
     501        def cancel(d): 
     502            self.cancellerCalled=True 
     503            d.errback(FooError()) 
     504        d=defer.Deferred(canceller=cancel) 
     505        d.addCallbacks(self._callback, self._errback) 
     506        d.cancel() 
     507        self.assertEquals(self.cancellerCalled, True) 
     508        self.assertEquals(self.errback_results.type, FooError) 
     509 
     510    def testCancelAlreadyCalled(self): 
     511        def cancel(d): 
     512            self.cancellerCalled=True 
     513        d=defer.Deferred(canceller=cancel) 
     514        d.callback(None) 
     515        self.assertRaises(defer.AlreadyCalledError, d.cancel) 
     516        self.assertEquals(self.cancellerCalled, False) 
     517     
     518    def testCancelNestedDeferred(self): 
     519        def innerCancel(d): 
     520            self.assertIdentical(d, innerDeferred) 
     521            self.cancellerCalled=True 
     522        def cancel(d): 
     523            self.assert_(False) 
     524             
     525        innerDeferred=defer.Deferred(canceller=innerCancel) 
     526        d=defer.Deferred(canceller=cancel) 
     527        d.callback(None) 
     528        d.addCallback(lambda data: innerDeferred) 
     529        d.cancel() 
     530        d.addCallbacks(self._callback, self._errback) 
     531        self.assertEquals(self.cancellerCalled, True) 
     532        self.assertEquals(self.errback_results.type, defer.CancelledError) 
     533 
    453534class LogTestCase(unittest.TestCase): 
    454535 
    455536    def setUp(self): 
     
    539620        self.failUnless(lock.locked) 
    540621        self.assertEquals(self.counter, 3) 
    541622 
     623        d = lock.acquire().addBoth(lambda x: setattr(self, 'result', x)) 
     624        d.cancel() 
     625        self.assertEquals(self.result.type, defer.CancelledError) 
     626         
    542627        lock.release() 
    543628        self.failIf(lock.locked) 
    544629 
     
    567652            sem.acquire().addCallback(self._incr) 
    568653            self.assertEquals(self.counter, i) 
    569654 
     655 
     656        success = [] 
     657        def fail(r): 
     658            success.append(False) 
     659        def succeed(r): 
     660            success.append(True) 
     661        d = sem.acquire().addCallbacks(fail, succeed) 
     662        d.cancel() 
     663        self.assertEquals(success, [True]) 
     664         
    570665        sem.acquire().addCallback(self._incr) 
    571666        self.assertEquals(self.counter, N) 
    572667 
    573668        sem.release() 
    574669        self.assertEquals(self.counter, N + 1) 
    575  
     670         
    576671        for i in range(1, 1 + N): 
    577672            sem.release() 
    578673            self.assertEquals(self.counter, N + 1) 
     
    614709        queue = defer.DeferredQueue(backlog=0) 
    615710        self.assertRaises(defer.QueueUnderflow, queue.get) 
    616711 
     712        queue = defer.DeferredQueue(size=0) 
     713 
     714        success = [] 
     715        def fail(r): 
     716            success.append(False) 
     717        def succeed(r): 
     718            success.append(True) 
     719        d = queue.get().addCallbacks(fail, succeed) 
     720        d.cancel() 
     721        self.assertEquals(success, [True]) 
     722        self.assertRaises(defer.QueueOverflow, queue.put, None)