root/trunk/twisted/trial/unittest.py

Revision 32905, 57.2 KB (checked in by exarkun, 7 months ago)

Merge assertdictequal-5291

Author: exarkun
Reviewer: therve
Fixes: #5291

Add a message parameter to failUnlessIsInstance (aka assertIsInstance).
This lets application code augment the exception message when the assertion
fails. It also makes the signature compatible with the stdlib unittest
version, which lets other stdlib unittest code which uses
assertIsInstance work properly - primarily assertDictEqual.

Line 
1# -*- test-case-name: twisted.trial.test.test_tests -*-
2# Copyright (c) Twisted Matrix Laboratories.
3# See LICENSE for details.
4
5"""
6Things likely to be used by writers of unit tests.
7
8Maintainer: Jonathan Lange
9"""
10
11
12import doctest, inspect
13import os, warnings, sys, tempfile, gc, types
14from pprint import pformat
15from dis import findlinestarts as _findlinestarts
16
17from twisted.internet import defer, utils
18from twisted.python import components, failure, log, monkey
19from twisted.python.deprecate import getDeprecationWarningString
20
21from twisted.trial import itrial, reporter, util
22
# Bind the standard library unittest module to the name ``pyunit``.
# NOTE(review): presumably loaded through __import__ because this module is
# itself named ``unittest`` and a plain import statement could resolve to this
# file instead of the stdlib one -- confirm.
pyunit = __import__('unittest')
24
25from zope.interface import implements
26
27
28
class SkipTest(Exception):
    """
    Exception a test raises, with a reason, to have itself skipped.

    Skipping can also be requested without raising: set C{method.skip} to a
    reason string to skip a single test method, or set C{class.skip} to skip
    the entire TestCase.
    """
35
36
class FailTest(AssertionError):
    """
    Raised to indicate the current test has failed to pass.
    """
39
40
class Todo(object):
    """
    Marker used internally to flag a L{TestCase} as 'todo'.

    Trial L{TestResult}s report 'todo' tests differently: a failing todo'd
    test does not fail the suite and its errors go into a separate category,
    while a passing todo'd test is reported as an unexpected success.
    """

    def __init__(self, reason, errors=None):
        """
        @param reason: A string explaining why the test is marked 'todo'

        @param errors: An iterable of exception types the test is expected to
        raise. An error of one of these types is trapped; any other kind of
        error fails the test. Pass C{None} to trap all errors.
        """
        self.errors = errors
        self.reason = reason

    def __repr__(self):
        details = (self.reason, self.errors)
        return "<Todo reason=%r errors=%r>" % details

    def expected(self, failure):
        """
        Decide whether C{failure} is one of the errors this todo anticipates.

        @param failure: A L{twisted.python.failure.Failure}.

        @return: C{True} if C{failure} is expected, C{False} otherwise.
        """
        anticipated = self.errors
        if anticipated is not None:
            for errorType in anticipated:
                if failure.check(errorType):
                    break
            else:
                # Ran out of candidates without a match.
                return False
        return True
77
78
def makeTodo(value):
    """
    Build a L{Todo} object from C{value}.

    A string becomes a Todo that expects any exception, with the string as its
    reason. A tuple is unpacked as C{(errors, reason)}: the first element gives
    the expected error(s) and the second element the reason.

    @param value: A string or a tuple of C{(errors, reason)}, where C{errors}
    is either a single exception class or an iterable of exception classes.

    @return: A L{Todo} object, or C{None} if C{value} is neither a string nor
    a tuple.
    """
    if isinstance(value, str):
        return Todo(reason=value)
    if not isinstance(value, tuple):
        return None
    expectedErrors, why = value
    try:
        expectedErrors = list(expectedErrors)
    except TypeError:
        # A single, non-iterable exception class was given.
        expectedErrors = [expectedErrors]
    return Todo(reason=why, errors=expectedErrors)
101
102
103
class _Warning(object):
    """
    A record of one warning emitted through the Python warning system
    (L{warnings}).  It insulates callers of L{_collectWarnings} from changes
    to the Python warnings system which might otherwise require changes to the
    warning objects that function passes to the observer object it accepts.

    @ivar message: The string which was passed as the message parameter to
        L{warnings.warn}.

    @ivar category: The L{Warning} subclass which was passed as the category
        parameter to L{warnings.warn}.

    @ivar filename: The name of the file containing the definition of the code
        object which was C{stacklevel} frames above the call to
        L{warnings.warn}, where C{stacklevel} is the value of the C{stacklevel}
        parameter passed to L{warnings.warn}.

    @ivar lineno: The source line associated with the active instruction of the
        code object which was C{stacklevel} frames above the call to
        L{warnings.warn}, where C{stacklevel} is the value of the C{stacklevel}
        parameter passed to L{warnings.warn}.
    """
    def __init__(self, message, category, filename, lineno):
        self.message, self.category = message, category
        self.filename, self.lineno = filename, lineno
133
134
def _setWarningRegistryToNone(modules):
    """
    Set C{__warningregistry__} to C{None} on every module found in
    C{modules}, typically C{sys.modules}, which disables the per-module
    warning cache.

    @param modules: Dictionary of modules, typically sys.module dict
    """
    for module in modules.values():
        if module is None:
            continue
        try:
            module.__warningregistry__ = None
        except:
            # Deliberately catch everything: some wacky object may raise some
            # wacky exception in response to the setattr attempt.
            pass
151
152
def _collectWarnings(observeWarning, f, *args, **kwargs):
    """
    Call C{f} with C{args} positional arguments and C{kwargs} keyword
    arguments, reporting every warning emitted as a result to
    C{observeWarning}.

    @param observeWarning: A callable which will be invoked with a L{_Warning}
        instance each time a warning is emitted.

    @return: The return value of C{f(*args, **kwargs)}.
    """
    def showWarning(message, category, filename, lineno, file=None, line=None):
        assert isinstance(message, Warning)
        text = message.args[0]
        observeWarning(_Warning(text, category, filename, lineno))

    # Disable the per-module cache for every module; otherwise a warning the
    # caller expects us to collect, but which was already emitted once, would
    # not be re-emitted by the call to f below.
    _setWarningRegistryToNone(sys.modules)

    savedFilters = warnings.filters[:]
    savedShow = warnings.showwarning
    warnings.simplefilter('always')
    try:
        warnings.showwarning = showWarning
        return f(*args, **kwargs)
    finally:
        # Put the global warnings state back exactly as it was found.
        warnings.filters[:] = savedFilters
        warnings.showwarning = savedShow
183
184
185
class _Assertions(pyunit.TestCase, object):
    """
    Replaces many of the built-in TestCase assertions. In general, these
    assertions provide better error messages and are easier to use in
    callbacks. Also provides new assertions such as L{failUnlessFailure}.

    Although the tests are defined as 'failIf*' and 'failUnless*', they can
    also be called as 'assertNot*' and 'assert*'.
    """

    def fail(self, msg=None):
        """
        Absolutely fail the test.  Do not pass go, do not collect $200.

        @param msg: the message that will be displayed as the reason for the
        failure

        @raise self.failureException: always.
        """
        raise self.failureException(msg)

    def failIf(self, condition, msg=None):
        """
        Fail the test if C{condition} evaluates to True.

        @param condition: any object that defines __nonzero__
        @param msg: the message passed to the failure exception.

        @return: C{condition}, if the assertion holds.
        """
        if condition:
            raise self.failureException(msg)
        return condition
    assertNot = assertFalse = failUnlessFalse = failIf

    def failUnless(self, condition, msg=None):
        """
        Fail the test if C{condition} evaluates to False.

        @param condition: any object that defines __nonzero__
        @param msg: the message passed to the failure exception.

        @return: C{condition}, if the assertion holds.
        """
        if not condition:
            raise self.failureException(msg)
        return condition
    assert_ = assertTrue = failUnlessTrue = failUnless

    def failUnlessRaises(self, exception, f, *args, **kwargs):
        """
        Fail the test unless calling the function C{f} with the given
        C{args} and C{kwargs} raises C{exception}. The failure will report
        the traceback and call stack of the unexpected exception.

        @param exception: exception type that is to be expected
        @param f: the function to call

        @return: The raised exception instance, if it is of the given type.
        @raise self.failureException: Raised if the function call does
            not raise an exception or if it raises an exception of a
            different type.
        """
        try:
            result = f(*args, **kwargs)
        except exception:
            # Fetch the instance from sys.exc_info() instead of binding it in
            # the except clause, so this source parses on every Python
            # version.
            return sys.exc_info()[1]
        except:
            raise self.failureException('%s raised instead of %s:\n %s'
                                        % (sys.exc_info()[0],
                                           exception.__name__,
                                           failure.Failure().getTraceback()))
        else:
            raise self.failureException('%s not raised (%r returned)'
                                        % (exception.__name__, result))
    assertRaises = failUnlessRaises


    def assertEqual(self, first, second, msg=''):
        """
        Fail the test if C{first} and C{second} are not equal.

        @param msg: A string describing the failure that's included in the
            exception.

        @return: C{first}, if the assertion holds.
        """
        if not first == second:
            if msg is None:
                msg = ''
            if len(msg) > 0:
                msg += '\n'
            raise self.failureException(
                '%snot equal:\na = %s\nb = %s\n'
                % (msg, pformat(first), pformat(second)))
        return first
    failUnlessEqual = failUnlessEquals = assertEquals = assertEqual


    def failUnlessIdentical(self, first, second, msg=None):
        """
        Fail the test if C{first} is not C{second}.  This is an
        object-identity-equality test, not an object equality
        (i.e. C{__eq__}) test.

        @param msg: if msg is None, then the failure message will be
        '%r is not %r' % (first, second)

        @return: C{first}, if the assertion holds.
        """
        if first is not second:
            raise self.failureException(msg or '%r is not %r' % (first, second))
        return first
    assertIdentical = failUnlessIdentical

    def failIfIdentical(self, first, second, msg=None):
        """
        Fail the test if C{first} is C{second}.  This is an
        object-identity-equality test, not an object equality
        (i.e. C{__eq__}) test.

        @param msg: if msg is None, then the failure message will be
        '%r is %r' % (first, second)

        @return: C{first}, if the assertion holds.
        """
        if first is second:
            raise self.failureException(msg or '%r is %r' % (first, second))
        return first
    assertNotIdentical = failIfIdentical

    def failIfEqual(self, first, second, msg=None):
        """
        Fail the test if C{first} == C{second}.

        @param msg: if msg is None, then the failure message will be
        '%r == %r' % (first, second)

        @return: C{first}, if the assertion holds.
        """
        if not first != second:
            raise self.failureException(msg or '%r == %r' % (first, second))
        return first
    assertNotEqual = assertNotEquals = failIfEquals = failIfEqual

    def failUnlessIn(self, containee, container, msg=None):
        """
        Fail the test if C{containee} is not found in C{container}.

        @param containee: the value that should be in C{container}
        @param container: a sequence type, or in the case of a mapping type,
                          will follow semantics of 'if key in dict.keys()'
        @param msg: if msg is None, then the failure message will be
                    '%r not in %r' % (containee, container)

        @return: C{containee}, if the assertion holds.
        """
        if containee not in container:
            raise self.failureException(msg or "%r not in %r"
                                        % (containee, container))
        return containee
    assertIn = failUnlessIn

    def failIfIn(self, containee, container, msg=None):
        """
        Fail the test if C{containee} is found in C{container}.

        @param containee: the value that should not be in C{container}
        @param container: a sequence type, or in the case of a mapping type,
                          will follow semantics of 'if key in dict.keys()'
        @param msg: if msg is None, then the failure message will be
                    '%r in %r' % (containee, container)

        @return: C{containee}, if the assertion holds.
        """
        if containee in container:
            raise self.failureException(msg or "%r in %r"
                                        % (containee, container))
        return containee
    assertNotIn = failIfIn

    def failIfAlmostEqual(self, first, second, places=7, msg=None):
        """
        Fail if the two objects are equal as determined by their
        difference rounded to the given number of decimal places
        (default 7) and comparing to zero.

        @note: decimal places (from zero) is usually not the same
               as significant digits (measured from the most
               significant digit).

        @note: included for compatibility with PyUnit test cases

        @return: C{first}, if the assertion holds.
        """
        if round(second-first, places) == 0:
            raise self.failureException(msg or '%r == %r within %r places'
                                        % (first, second, places))
        return first
    assertNotAlmostEqual = assertNotAlmostEquals = failIfAlmostEqual
    failIfAlmostEquals = failIfAlmostEqual

    def failUnlessAlmostEqual(self, first, second, places=7, msg=None):
        """
        Fail if the two objects are unequal as determined by their
        difference rounded to the given number of decimal places
        (default 7) and comparing to zero.

        @note: decimal places (from zero) is usually not the same
               as significant digits (measured from the most
               significant digit).

        @note: included for compatibility with PyUnit test cases

        @return: C{first}, if the assertion holds.
        """
        if round(second-first, places) != 0:
            raise self.failureException(msg or '%r != %r within %r places'
                                        % (first, second, places))
        return first
    assertAlmostEqual = assertAlmostEquals = failUnlessAlmostEqual
    failUnlessAlmostEquals = failUnlessAlmostEqual

    def failUnlessApproximates(self, first, second, tolerance, msg=None):
        """
        Fail if abs(C{first} - C{second}) > C{tolerance}

        @param msg: if msg is None, then the failure message will be
                    '%s ~== %s' % (first, second)

        @return: C{first}, if the assertion holds.
        """
        if abs(first - second) > tolerance:
            raise self.failureException(msg or "%s ~== %s" % (first, second))
        return first
    assertApproximates = failUnlessApproximates

    def failUnlessFailure(self, deferred, *expectedFailures):
        """
        Fail if C{deferred} does not errback with one of C{expectedFailures}.
        Returns the original Deferred with callbacks added. You will need
        to return this Deferred from your test case.
        """
        def _cb(ignore):
            raise self.failureException(
                "did not catch an error, instead got %r" % (ignore,))

        def _eb(failure):
            if failure.check(*expectedFailures):
                # The expected error: hand the exception itself on as the
                # Deferred's result.
                return failure.value
            else:
                output = ('\nExpected: %r\nGot:\n%s'
                          % (expectedFailures, str(failure)))
                raise self.failureException(output)
        return deferred.addCallbacks(_cb, _eb)
    assertFailure = failUnlessFailure

    def failUnlessSubstring(self, substring, astring, msg=None):
        """
        Fail if C{substring} does not exist within C{astring}.
        """
        return self.failUnlessIn(substring, astring, msg)
    assertSubstring = failUnlessSubstring

    def failIfSubstring(self, substring, astring, msg=None):
        """
        Fail if C{astring} contains C{substring}.
        """
        return self.failIfIn(substring, astring, msg)
    assertNotSubstring = failIfSubstring

    def failUnlessWarns(self, category, message, filename, f,
                       *args, **kwargs):
        """
        Fail if the given function doesn't generate the specified warning when
        called. It calls the function, checks the warning, and forwards the
        result of the function if everything is fine.

        @param category: the category of the warning to check.
        @param message: the output message of the warning to check.
        @param filename: the filename where the warning should come from.
        @param f: the function which is supposed to generate the warning.
        @type f: any callable.
        @param args: the arguments to C{f}.
        @param kwargs: the keywords arguments to C{f}.

        @return: the result of the original function C{f}.
        """
        warningsShown = []
        result = _collectWarnings(warningsShown.append, f, *args, **kwargs)

        if not warningsShown:
            self.fail("No warnings emitted")
        first = warningsShown[0]
        # Several emissions are only acceptable if they are all the same
        # warning (same message and category).
        for other in warningsShown[1:]:
            if ((other.message, other.category)
                != (first.message, first.category)):
                self.fail("Can't handle different warnings")
        self.assertEqual(first.message, message)
        self.assertIdentical(first.category, category)

        # Use starts with because of .pyc/.pyo issues.
        self.failUnless(
            filename.startswith(first.filename),
            'Warning in %r, expected %r' % (first.filename, filename))

        # It would be nice to be able to check the line number as well, but
        # different configurations actually end up reporting different line
        # numbers (generally the variation is only 1 line, but that's enough
        # to fail the test erroneously...).
        # self.assertEqual(lineno, xxx)

        return result
    assertWarns = failUnlessWarns

    def failUnlessIsInstance(self, instance, classOrTuple, message=None):
        """
        Fail if C{instance} is not an instance of the given class or of
        one of the given classes.

        @param instance: the object to test the type (first argument of the
            C{isinstance} call).
        @type instance: any.
        @param classOrTuple: the class or classes to test against (second
            argument of the C{isinstance} call).
        @type classOrTuple: class, type, or tuple.

        @param message: Custom text to include in the exception text if the
            assertion fails.
        """
        if not isinstance(instance, classOrTuple):
            if message is None:
                suffix = ""
            else:
                suffix = ": " + message
            self.fail("%r is not an instance of %s%s" % (
                    instance, classOrTuple, suffix))
    assertIsInstance = failUnlessIsInstance

    def failIfIsInstance(self, instance, classOrTuple, message=None):
        """
        Fail if C{instance} is an instance of the given class or of one of
        the given classes.

        @param instance: the object to test the type (first argument of the
            C{isinstance} call).
        @type instance: any.
        @param classOrTuple: the class or classes to test against (second
            argument of the C{isinstance} call).
        @type classOrTuple: class, type, or tuple.

        @param message: Custom text to include in the exception text if the
            assertion fails.
        """
        if isinstance(instance, classOrTuple):
            if message is None:
                suffix = ""
            else:
                suffix = ": " + message
            self.fail("%r is an instance of %s%s" % (
                    instance, classOrTuple, suffix))
    assertNotIsInstance = failIfIsInstance
514
515
class _LogObserver(object):
    """
    Observes the Twisted logs and catches any errors.

    @ivar _errors: A C{list} of L{Failure} instances which were received as
        error events from the Twisted logging system.

    @ivar _added: A C{int} giving the number of times C{_add} has been called
        less the number of times C{_remove} has been called; used to only add
        this observer to the Twisted logging system once, regardless of the
        number of calls to the add method.

    @ivar _ignored: A C{list} of exception types which will not be recorded.
    """

    def __init__(self):
        self._ignored = []
        self._added = 0
        self._errors = []


    def _add(self):
        """
        Register this observer with the log system on the first call; later
        calls only bump the counter.
        """
        if self._added == 0:
            log.addObserver(self.gotEvent)
            # Remember each log module hook, then point it at this observer.
            self._oldFE = log._flushErrors
            log._flushErrors = self.flushErrors
            self._oldIE = log._ignore
            log._ignore = self._ignoreErrors
            self._oldCI = log._clearIgnores
            log._clearIgnores = self._clearIgnores
        self._added += 1

    def _remove(self):
        """
        Undo one call to C{_add}; when the counter reaches zero, unregister
        the observer and put the saved log module hooks back.
        """
        self._added -= 1
        if self._added != 0:
            return
        log.removeObserver(self.gotEvent)
        log._flushErrors = self._oldFE
        log._ignore = self._oldIE
        log._clearIgnores = self._oldCI


    def _ignoreErrors(self, *errorTypes):
        """
        Do not store any errors with any of the given types.
        """
        self._ignored.extend(errorTypes)


    def _clearIgnores(self):
        """
        Stop ignoring any errors we might currently be ignoring.
        """
        self._ignored = []


    def flushErrors(self, *errorTypes):
        """
        Remove errors from the list of caught errors and return them. With no
        arguments every error is removed; otherwise only errors of the given
        types are removed from the stored list.
        """
        if not errorTypes:
            everything, self._errors = self._errors, []
            return everything
        flushed = []
        kept = []
        for caught in self._errors:
            if caught.check(*errorTypes):
                destination = flushed
            else:
                destination = kept
            destination.append(caught)
        self._errors = kept
        return flushed


    def getErrors(self):
        """
        Return a list of errors caught by this observer.
        """
        return self._errors


    def gotEvent(self, event):
        """
        The actual observer method. Called whenever a message is logged.

        @param event: A dictionary containing the log message. Actual
        structure undocumented (see source for L{twisted.python.log}).
        """
        if not event.get('isError', False):
            return
        if 'failure' not in event:
            return
        f = event['failure']
        if len(self._ignored) == 0 or not f.check(*self._ignored):
            self._errors.append(f)
608
609
610
# Module-level L{_LogObserver} instance, created at import time.
_logObserver = _LogObserver()

# NOTE(review): starts out as an empty list; it is not referenced in the
# visible part of this module -- confirm its role against C{_wait}.
_wait_is_running = []
614
615class TestCase(_Assertions):
616    """
617    A unit test. The atom of the unit testing universe.
618
619    This class extends C{unittest.TestCase} from the standard library. The
620    main feature is the ability to return C{Deferred}s from tests and fixture
621    methods and to have the suite wait for those C{Deferred}s to fire.
622
623    To write a unit test, subclass C{TestCase} and define a method (say,
624    'test_foo') on the subclass. To run the test, instantiate your subclass
625    with the name of the method, and call L{run} on the instance, passing a
626    L{TestResult} object.
627
628    The C{trial} script will automatically find any C{TestCase} subclasses
629    defined in modules beginning with 'test_' and construct test cases for all
630    methods beginning with 'test'.
631
632    If an error is logged during the test run, the test will fail with an
633    error. See L{log.err}.
634
635    @ivar failureException: An exception class, defaulting to C{FailTest}. If
636    the test method raises this exception, it will be reported as a failure,
637    rather than an exception. All of the assertion methods raise this if the
638    assertion fails.
639
640    @ivar skip: C{None} or a string explaining why this test is to be
641    skipped. If defined, the test will not be run. Instead, it will be
642    reported to the result object as 'skipped' (if the C{TestResult} supports
643    skipping).
644
645    @ivar suppress: C{None} or a list of tuples of C{(args, kwargs)} to be
646    passed to C{warnings.filterwarnings}. Use these to suppress warnings
647    raised in a test. Useful for testing deprecated code. See also
648    L{util.suppress}.
649
650    @ivar timeout: A real number of seconds. If set, the test will
651    raise an error if it takes longer than C{timeout} seconds.
652    If not set, util.DEFAULT_TIMEOUT_DURATION is used.
653
654    @ivar todo: C{None}, a string or a tuple of C{(errors, reason)} where
655    C{errors} is either an exception class or an iterable of exception
656    classes, and C{reason} is a string. See L{Todo} or L{makeTodo} for more
657    information.
658    """
659
660    implements(itrial.ITestCase)
661    failureException = FailTest
662
663    def __init__(self, methodName='runTest'):
664        """
665        Construct an asynchronous test case for C{methodName}.
666
667        @param methodName: The name of a method on C{self}. This method should
668        be a unit test. That is, it should be a short method that calls some of
669        the assert* methods. If C{methodName} is unspecified, L{runTest} will
670        be used as the test method. This is mostly useful for testing Trial.
671        """
672        super(TestCase, self).__init__(methodName)
673        self._testMethodName = methodName
674        testMethod = getattr(self, methodName)
675        self._parents = [testMethod, self]
676        self._parents.extend(util.getPythonContainers(testMethod))
677        self._passed = False
678        self._cleanups = []
679
    if sys.version_info >= (2, 6):
        # Override the comparison defined by the base TestCase which considers
        # instances of the same class with the same _testMethodName to be
        # equal.  Since trial puts TestCase instances into a set, that
        # definition of comparison makes it impossible to run the same test
        # method twice.  Most likely, trial should stop using a set to hold
        # tests, but until it does, this is necessary on Python 2.6.  Only
        # __eq__ and __ne__ are required here, not __hash__, since the
        # inherited __hash__ is compatible with these equality semantics.  A
        # different __hash__ might be slightly more efficient (by reducing
        # collisions), but who cares? -exarkun
        def __eq__(self, other):
            """
            A test case is equal only to itself (identity comparison).
            """
            return self is other

        def __ne__(self, other):
            """
            A test case is unequal to every object other than itself.
            """
            return self is not other
696
697
698    def _run(self, methodName, result):
699        from twisted.internet import reactor
700        timeout = self.getTimeout()
701        def onTimeout(d):
702            e = defer.TimeoutError("%r (%s) still running at %s secs"
703                % (self, methodName, timeout))
704            f = failure.Failure(e)
705            # try to errback the deferred that the test returns (for no gorram
706            # reason) (see issue1005 and test_errorPropagation in
707            # test_deferred)
708            try:
709                d.errback(f)
710            except defer.AlreadyCalledError:
711                # if the deferred has been called already but the *back chain
712                # is still unfinished, crash the reactor and report timeout
713                # error ourself.
714                reactor.crash()
715                self._timedOut = True # see self._wait
716                todo = self.getTodo()
717                if todo is not None and todo.expected(f):
718                    result.addExpectedFailure(self, f, todo)
719                else:
720                    result.addError(self, f)
721        onTimeout = utils.suppressWarnings(
722            onTimeout, util.suppress(category=DeprecationWarning))
723        method = getattr(self, methodName)
724        d = defer.maybeDeferred(utils.runWithWarningsSuppressed,
725                                self.getSuppress(), method)
726        call = reactor.callLater(timeout, onTimeout, d)
727        d.addBoth(lambda x : call.active() and call.cancel() or x)
728        return d
729
730    def shortDescription(self):
731        desc = super(TestCase, self).shortDescription()
732        if desc is None:
733            return self._testMethodName
734        return desc
735
    def __call__(self, *args, **kwargs):
        """
        Calling the test case delegates to L{run}, passing all arguments
        through and returning its result.
        """
        return self.run(*args, **kwargs)
738
739    def deferSetUp(self, ignored, result):
740        d = self._run('setUp', result)
741        d.addCallbacks(self.deferTestMethod, self._ebDeferSetUp,
742                       callbackArgs=(result,),
743                       errbackArgs=(result,))
744        return d
745
746    def _ebDeferSetUp(self, failure, result):
747        if failure.check(SkipTest):
748            result.addSkip(self, self._getReason(failure))
749        else:
750            result.addError(self, failure)
751            if failure.check(KeyboardInterrupt):
752                result.stop()
753        return self.deferRunCleanups(None, result)
754
755    def deferTestMethod(self, ignored, result):
756        d = self._run(self._testMethodName, result)
757        d.addCallbacks(self._cbDeferTestMethod, self._ebDeferTestMethod,
758                       callbackArgs=(result,),
759                       errbackArgs=(result,))
760        d.addBoth(self.deferRunCleanups, result)
761        d.addBoth(self.deferTearDown, result)
762        return d
763
764    def _cbDeferTestMethod(self, ignored, result):
765        if self.getTodo() is not None:
766            result.addUnexpectedSuccess(self, self.getTodo())
767        else:
768            self._passed = True
769        return ignored
770
771    def _ebDeferTestMethod(self, f, result):
772        todo = self.getTodo()
773        if todo is not None and todo.expected(f):
774            result.addExpectedFailure(self, f, todo)
775        elif f.check(self.failureException, FailTest):
776            result.addFailure(self, f)
777        elif f.check(KeyboardInterrupt):
778            result.addError(self, f)
779            result.stop()
780        elif f.check(SkipTest):
781            result.addSkip(self, self._getReason(f))
782        else:
783            result.addError(self, f)
784
785    def deferTearDown(self, ignored, result):
786        d = self._run('tearDown', result)
787        d.addErrback(self._ebDeferTearDown, result)
788        return d
789
790    def _ebDeferTearDown(self, failure, result):
791        result.addError(self, failure)
792        if failure.check(KeyboardInterrupt):
793            result.stop()
794        self._passed = False
795
796    def deferRunCleanups(self, ignored, result):
797        """
798        Run any scheduled cleanups and report errors (if any to the result
799        object.
800        """
801        d = self._runCleanups()
802        d.addCallback(self._cbDeferRunCleanups, result)
803        return d
804
805    def _cbDeferRunCleanups(self, cleanupResults, result):
806        for flag, failure in cleanupResults:
807            if flag == defer.FAILURE:
808                result.addError(self, failure)
809                if failure.check(KeyboardInterrupt):
810                    result.stop()
811                self._passed = False
812
813    def _cleanUp(self, result):
814        try:
815            clean = util._Janitor(self, result).postCaseCleanup()
816            if not clean:
817                self._passed = False
818        except:
819            result.addError(self, failure.Failure())
820            self._passed = False
821        for error in self._observer.getErrors():
822            result.addError(self, error)
823            self._passed = False
824        self.flushLoggedErrors()
825        self._removeObserver()
826        if self._passed:
827            result.addSuccess(self)
828
829    def _classCleanUp(self, result):
830        try:
831            util._Janitor(self, result).postClassCleanup()
832        except:
833            result.addError(self, failure.Failure())
834
835    def _makeReactorMethod(self, name):
836        """
837        Create a method which wraps the reactor method C{name}. The new
838        method issues a deprecation warning and calls the original.
839        """
840        def _(*a, **kw):
841            warnings.warn("reactor.%s cannot be used inside unit tests. "
842                          "In the future, using %s will fail the test and may "
843                          "crash or hang the test run."
844                          % (name, name),
845                          stacklevel=2, category=DeprecationWarning)
846            return self._reactorMethods[name](*a, **kw)
847        return _
848
849    def _deprecateReactor(self, reactor):
850        """
851        Deprecate C{iterate}, C{crash} and C{stop} on C{reactor}. That is,
852        each method is wrapped in a function that issues a deprecation
853        warning, then calls the original.
854
855        @param reactor: The Twisted reactor.
856        """
857        self._reactorMethods = {}
858        for name in ['crash', 'iterate', 'stop']:
859            self._reactorMethods[name] = getattr(reactor, name)
860            setattr(reactor, name, self._makeReactorMethod(name))
861
862    def _undeprecateReactor(self, reactor):
863        """
864        Restore the deprecated reactor methods. Undoes what
865        L{_deprecateReactor} did.
866
867        @param reactor: The Twisted reactor.
868        """
869        for name, method in self._reactorMethods.iteritems():
870            setattr(reactor, name, method)
871        self._reactorMethods = {}
872
873    def _installObserver(self):
874        self._observer = _logObserver
875        self._observer._add()
876
877    def _removeObserver(self):
878        self._observer._remove()
879
880    def flushLoggedErrors(self, *errorTypes):
881        """
882        Remove stored errors received from the log.
883
884        C{TestCase} stores each error logged during the run of the test and
885        reports them as errors during the cleanup phase (after C{tearDown}).
886
887        @param *errorTypes: If unspecifed, flush all errors. Otherwise, only
888        flush errors that match the given types.
889
890        @return: A list of failures that have been removed.
891        """
892        return self._observer.flushErrors(*errorTypes)
893
894
895    def flushWarnings(self, offendingFunctions=None):
896        """
897        Remove stored warnings from the list of captured warnings and return
898        them.
899
900        @param offendingFunctions: If C{None}, all warnings issued during the
901            currently running test will be flushed.  Otherwise, only warnings
902            which I{point} to a function included in this list will be flushed.
903            All warnings include a filename and source line number; if these
904            parts of a warning point to a source line which is part of a
905            function, then the warning I{points} to that function.
906        @type offendingFunctions: L{NoneType} or L{list} of functions or methods.
907
908        @raise ValueError: If C{offendingFunctions} is not C{None} and includes
909            an object which is not a L{FunctionType} or L{MethodType} instance.
910
911        @return: A C{list}, each element of which is a C{dict} giving
912            information about one warning which was flushed by this call.  The
913            keys of each C{dict} are:
914
915                - C{'message'}: The string which was passed as the I{message}
916                  parameter to L{warnings.warn}.
917
918                - C{'category'}: The warning subclass which was passed as the
919                  I{category} parameter to L{warnings.warn}.
920
921                - C{'filename'}: The name of the file containing the definition
922                  of the code object which was C{stacklevel} frames above the
923                  call to L{warnings.warn}, where C{stacklevel} is the value of
924                  the C{stacklevel} parameter passed to L{warnings.warn}.
925
926                - C{'lineno'}: The source line associated with the active
927                  instruction of the code object object which was C{stacklevel}
928                  frames above the call to L{warnings.warn}, where
929                  C{stacklevel} is the value of the C{stacklevel} parameter
930                  passed to L{warnings.warn}.
931        """
932        if offendingFunctions is None:
933            toFlush = self._warnings[:]
934            self._warnings[:] = []
935        else:
936            toFlush = []
937            for aWarning in self._warnings:
938                for aFunction in offendingFunctions:
939                    if not isinstance(aFunction, (
940                            types.FunctionType, types.MethodType)):
941                        raise ValueError("%r is not a function or method" % (
942                                aFunction,))
943
944                    # inspect.getabsfile(aFunction) sometimes returns a
945                    # filename which disagrees with the filename the warning
946                    # system generates.  This seems to be because a
947                    # function's code object doesn't deal with source files
948                    # being renamed.  inspect.getabsfile(module) seems
949                    # better (or at least agrees with the warning system
950                    # more often), and does some normalization for us which
951                    # is desirable.  inspect.getmodule() is attractive, but
952                    # somewhat broken in Python < 2.6.  See Python bug 4845.
953                    aModule = sys.modules[aFunction.__module__]
954                    filename = inspect.getabsfile(aModule)
955
956                    if filename != os.path.normcase(aWarning.filename):
957                        continue
958                    lineStarts = list(_findlinestarts(aFunction.func_code))
959                    first = lineStarts[0][1]
960                    last = lineStarts[-1][1]
961                    if not (first <= aWarning.lineno <= last):
962                        continue
963                    # The warning points to this function, flush it and move on
964                    # to the next warning.
965                    toFlush.append(aWarning)
966                    break
967            # Remove everything which is being flushed.
968            map(self._warnings.remove, toFlush)
969
970        return [
971            {'message': w.message, 'category': w.category,
972             'filename': w.filename, 'lineno': w.lineno}
973            for w in toFlush]
974
975
976    def addCleanup(self, f, *args, **kwargs):
977        """
978        Add the given function to a list of functions to be called after the
979        test has run, but before C{tearDown}.
980
981        Functions will be run in reverse order of being added. This helps
982        ensure that tear down complements set up.
983
984        The function C{f} may return a Deferred. If so, C{TestCase} will wait
985        until the Deferred has fired before proceeding to the next function.
986        """
987        self._cleanups.append((f, args, kwargs))
988
989
990    def callDeprecated(self, version, f, *args, **kwargs):
991        """
992        Call a function that should have been deprecated at a specific version
993        and in favor of a specific alternative, and assert that it was thusly
994        deprecated.
995
996        @param version: A 2-sequence of (since, replacement), where C{since} is
997            a the first L{version<twisted.python.versions.Version>} that C{f}
998            should have been deprecated since, and C{replacement} is a suggested
999            replacement for the deprecated functionality, as described by
1000            L{twisted.python.deprecate.deprecated}.  If there is no suggested
1001            replacement, this parameter may also be simply a
1002            L{version<twisted.python.versions.Version>} by itself.
1003
1004        @param f: The deprecated function to call.
1005
1006        @param args: The arguments to pass to C{f}.
1007
1008        @param kwargs: The keyword arguments to pass to C{f}.
1009
1010        @return: Whatever C{f} returns.
1011
1012        @raise: Whatever C{f} raises.  If any exception is
1013            raised by C{f}, though, no assertions will be made about emitted
1014            deprecations.
1015
1016        @raise FailTest: if no warnings were emitted by C{f}, or if the
1017            L{DeprecationWarning} emitted did not produce the canonical
1018            please-use-something-else message that is standard for Twisted
1019            deprecations according to the given version and replacement.
1020        """
1021        result = f(*args, **kwargs)
1022        warningsShown = self.flushWarnings([self.callDeprecated])
1023        try:
1024            info = list(version)
1025        except TypeError:
1026            since = version
1027            replacement = None
1028        else:
1029            [since, replacement] = info
1030
1031        if len(warningsShown) == 0:
1032            self.fail('%r is not deprecated.' % (f,))
1033
1034        observedWarning = warningsShown[0]['message']
1035        expectedWarning = getDeprecationWarningString(
1036            f, since, replacement=replacement)
1037        self.assertEqual(expectedWarning, observedWarning)
1038
1039        return result
1040
1041
1042    def _runCleanups(self):
1043        """
1044        Run the cleanups added with L{addCleanup} in order.
1045
1046        @return: A C{Deferred} that fires when all cleanups are run.
1047        """
1048        def _makeFunction(f, args, kwargs):
1049            return lambda: f(*args, **kwargs)
1050        callables = []
1051        while len(self._cleanups) > 0:
1052            f, args, kwargs = self._cleanups.pop()
1053            callables.append(_makeFunction(f, args, kwargs))
1054        return util._runSequentially(callables)
1055
1056
1057    def patch(self, obj, attribute, value):
1058        """
1059        Monkey patch an object for the duration of the test.
1060
1061        The monkey patch will be reverted at the end of the test using the
1062        L{addCleanup} mechanism.
1063
1064        The L{MonkeyPatcher} is returned so that users can restore and
1065        re-apply the monkey patch within their tests.
1066
1067        @param obj: The object to monkey patch.
1068        @param attribute: The name of the attribute to change.
1069        @param value: The value to set the attribute to.
1070        @return: A L{monkey.MonkeyPatcher} object.
1071        """
1072        monkeyPatch = monkey.MonkeyPatcher((obj, attribute, value))
1073        monkeyPatch.patch()
1074        self.addCleanup(monkeyPatch.restore)
1075        return monkeyPatch
1076
1077
    def runTest(self):
        """
        If no C{methodName} argument is passed to the constructor, L{run} will
        treat this method as the thing with the actual test inside.

        This base implementation has no body of its own and returns C{None}.
        """
1083
1084
    def run(self, result):
        """
        Run the test case, storing the results in C{result}.

        First runs C{setUp} on self, then runs the test method (defined in the
        constructor), then runs C{tearDown}. Any of these may return
        L{Deferred}s. After they complete, does some reactor cleanup.

        A test for which L{getSkip} returns a true value is reported as
        skipped and none of the above is run.

        @param result: A L{TestResult} object.  If it cannot be adapted to
            L{itrial.IReporter} it is wrapped in a L{PyUnitResultAdapter}.

        @return: C{None}
        """
        log.msg("--> %s <--" % (self.id()))
        from twisted.internet import reactor
        # Adapt the result; a None adaptation means a plain pyunit-style
        # result which needs the compatibility wrapper.
        new_result = itrial.IReporter(result, None)
        if new_result is None:
            result = PyUnitResultAdapter(result)
        else:
            result = new_result
        self._timedOut = False
        result.startTest(self)
        if self.getSkip(): # don't run test methods that are marked as .skip
            result.addSkip(self, self.getSkip())
            result.stopTest(self)
            return
        self._installObserver()

        # All the code inside runThunk will be run such that warnings emitted
        # by it will be collected and retrievable by flushWarnings.
        def runThunk():
            self._passed = False
            # The reactor stays "deprecated" for the whole of set up, test,
            # tear down and cleanup; the outer finally always restores it.
            self._deprecateReactor(reactor)
            try:
                d = self.deferSetUp(None, result)
                try:
                    self._wait(d)
                finally:
                    # Cleanup happens even if waiting raised.
                    self._cleanUp(result)
                    self._classCleanUp(result)
            finally:
                self._undeprecateReactor(reactor)

        # Fresh store for this run; runThunk's warnings are appended to it.
        self._warnings = []
        _collectWarnings(self._warnings.append, runThunk)

        # Any collected warnings which the test method didn't flush get
        # re-emitted so they'll be logged or show up on stdout or whatever.
        for w in self.flushWarnings():
            try:
                warnings.warn_explicit(**w)
            except:
                # Anything raised while re-emitting is reported as an error
                # of this test.
                result.addError(self, failure.Failure())

        result.stopTest(self)
1137
1138
1139    def _getReason(self, f):
1140        if len(f.value.args) > 0:
1141            reason = f.value.args[0]
1142        else:
1143            warnings.warn(("Do not raise unittest.SkipTest with no "
1144                           "arguments! Give a reason for skipping tests!"),
1145                          stacklevel=2)
1146            reason = f
1147        return reason
1148
1149    def getSkip(self):
1150        """
1151        Return the skip reason set on this test, if any is set. Checks on the
1152        instance first, then the class, then the module, then packages. As
1153        soon as it finds something with a C{skip} attribute, returns that.
1154        Returns C{None} if it cannot find anything. See L{TestCase} docstring
1155        for more details.
1156        """
1157        return util.acquireAttribute(self._parents, 'skip', None)
1158
1159    def getTodo(self):
1160        """
1161        Return a L{Todo} object if the test is marked todo. Checks on the
1162        instance first, then the class, then the module, then packages. As
1163        soon as it finds something with a C{todo} attribute, returns that.
1164        Returns C{None} if it cannot find anything. See L{TestCase} docstring
1165        for more details.
1166        """
1167        todo = util.acquireAttribute(self._parents, 'todo', None)
1168        if todo is None:
1169            return None
1170        return makeTodo(todo)
1171
1172    def getTimeout(self):
1173        """
1174        Returns the timeout value set on this test. Checks on the instance
1175        first, then the class, then the module, then packages. As soon as it
1176        finds something with a C{timeout} attribute, returns that. Returns
1177        L{util.DEFAULT_TIMEOUT_DURATION} if it cannot find anything. See
1178        L{TestCase} docstring for more details.
1179        """
1180        timeout =  util.acquireAttribute(self._parents, 'timeout',
1181                                         util.DEFAULT_TIMEOUT_DURATION)
1182        try:
1183            return float(timeout)
1184        except (ValueError, TypeError):
1185            # XXX -- this is here because sometimes people will have methods
1186            # called 'timeout', or set timeout to 'orange', or something
1187            # Particularly, test_news.NewsTestCase and ReactorCoreTestCase
1188            # both do this.
1189            warnings.warn("'timeout' attribute needs to be a number.",
1190                          category=DeprecationWarning)
1191            return util.DEFAULT_TIMEOUT_DURATION
1192
1193    def getSuppress(self):
1194        """
1195        Returns any warning suppressions set for this test. Checks on the
1196        instance first, then the class, then the module, then packages. As
1197        soon as it finds something with a C{suppress} attribute, returns that.
1198        Returns any empty list (i.e. suppress no warnings) if it cannot find
1199        anything. See L{TestCase} docstring for more details.
1200        """
1201        return util.acquireAttribute(self._parents, 'suppress', [])
1202
1203
1204    def visit(self, visitor):
1205        """
1206        Visit this test case. Call C{visitor} with C{self} as a parameter.
1207
1208        Deprecated in Twisted 8.0.
1209
1210        @param visitor: A callable which expects a single parameter: a test
1211        case.
1212
1213        @return: None
1214        """
1215        warnings.warn("Test visitors deprecated in Twisted 8.0",
1216                      category=DeprecationWarning)
1217        visitor(self)
1218
1219
1220    def mktemp(self):
1221        """Returns a unique name that may be used as either a temporary
1222        directory or filename.
1223
1224        @note: you must call os.mkdir on the value returned from this
1225               method if you wish to use it as a directory!
1226        """
1227        MAX_FILENAME = 32 # some platforms limit lengths of filenames
1228        base = os.path.join(self.__class__.__module__[:MAX_FILENAME],
1229                            self.__class__.__name__[:MAX_FILENAME],
1230                            self._testMethodName[:MAX_FILENAME])
1231        if not os.path.exists(base):
1232            os.makedirs(base)
1233        dirname = tempfile.mkdtemp('', '', base)
1234        return os.path.join(dirname, 'temp')
1235
    def _wait(self, d, running=_wait_is_running):
        """Take a Deferred that only ever callbacks. Block until it happens.

        Blocking is done by running the reactor until C{d} fires, unless C{d}
        has already fired when this is called.

        @param d: The Deferred to wait for.
        @param running: Re-entrancy guard shared between calls through the
            default argument; an entry is appended for the duration of the
            call.  Callers are presumably not meant to pass it -- the default
            is defined outside this method.

        @raise RuntimeError: If a call to C{_wait} is already in progress.
        @raise KeyboardInterrupt: If the reactor stopped running although
            C{d} has not fired and the test has not timed out.
        """
        from twisted.internet import reactor
        if running:
            raise RuntimeError("_wait is not reentrant")

        # Rebound to None on exit (see the final clause below), which turns
        # the two helpers that test it into no-ops once this call is over.
        results = []
        def append(any):
            if results is not None:
                results.append(any)
        def crash(ign):
            if results is not None:
                reactor.crash()
        # Both helpers call reactor.crash, which is wrapped to warn while a
        # test runs; silence that warning for our own use.
        crash = utils.suppressWarnings(
            crash, util.suppress(message=r'reactor\.crash cannot be used.*',
                                 category=DeprecationWarning))
        def stop():
            reactor.crash()
        stop = utils.suppressWarnings(
            stop, util.suppress(message=r'reactor\.crash cannot be used.*',
                                category=DeprecationWarning))

        running.append(None)
        try:
            d.addBoth(append)
            if results:
                # d might have already been fired, in which case append is
                # called synchronously. Avoid any reactor stuff.
                return
            d.addBoth(crash)
            # While we wait, reactor.stop only crashes the reactor; the
            # attribute set here is deleted again as soon as run() returns.
            reactor.stop = stop
            try:
                reactor.run()
            finally:
                del reactor.stop

            # If the reactor was crashed elsewhere due to a timeout, hopefully
            # that crasher also reported an error. Just return.
            # _timedOut is most likely to be set when d has fired but hasn't
            # completed its callback chain (see self._run)
            if results or self._timedOut: #defined in run() and _run()
                return

            # If the timeout didn't happen, and we didn't get a result or
            # a failure, then the user probably aborted the test, so let's
            # just raise KeyboardInterrupt.

            # FIXME: imagine this:
            # web/test/test_webclient.py:
            # exc = self.assertRaises(error.Error, wait, method(url))
            #
            # wait() will raise KeyboardInterrupt, and assertRaises will
            # swallow it. Therefore, wait() raising KeyboardInterrupt is
            # insufficient to stop trial. A suggested solution is to have
            # this code set a "stop trial" flag, or otherwise notify trial
            # that it should really try to stop as soon as possible.
            raise KeyboardInterrupt()
        finally:
            # Disarm append/crash and release the re-entrancy guard.
            results = None
            running.pop()
1297
1298
class UnsupportedTrialFeature(Exception):
    """A feature of twisted.trial was used that pyunit cannot support."""



class PyUnitResultAdapter(object):
    """
    Wrap a C{TestResult} from the standard library's C{unittest} so that it
    supports the extended result types from Trial, and also supports
    L{twisted.python.failure.Failure}s being passed to L{addError} and
    L{addFailure}.

    Result types the wrapped object has no notion of (skips, expected
    failures, unexpected successes) are reported to it as failures carrying
    an L{UnsupportedTrialFeature}.
    """

    def __init__(self, original):
        """
        @param original: A C{TestResult} instance from C{unittest}.
        """
        self.original = original

    def _exc_info(self, err):
        """
        Convert C{err} with L{util.excInfoOrFailureToExcInfo} before it is
        passed to the wrapped result.
        """
        return util.excInfoOrFailureToExcInfo(err)

    def startTest(self, method):
        """
        Forward to the wrapped result.
        """
        self.original.startTest(method)

    def stopTest(self, method):
        """
        Forward to the wrapped result.
        """
        self.original.stopTest(method)

    def addFailure(self, test, fail):
        """
        Forward to the wrapped result, converting C{fail} with L{_exc_info}.
        """
        self.original.addFailure(test, self._exc_info(fail))

    def addError(self, test, error):
        """
        Forward to the wrapped result, converting C{error} with L{_exc_info}.
        """
        self.original.addError(test, self._exc_info(error))

    def _unsupported(self, test, feature, info):
        """
        Report use of a Trial-only feature as a failure of C{test}.

        @param test: The test the report is about.
        @param feature: A short name for the unsupported feature.
        @param info: Extra detail; stored with C{feature} as the arguments of
            the L{UnsupportedTrialFeature} exception.
        """
        # There is no traceback to offer, hence the trailing None.
        self.original.addFailure(
            test,
            (UnsupportedTrialFeature,
             UnsupportedTrialFeature(feature, info),
             None))

    def addSkip(self, test, reason):
        """
        Report the skip as a failure.
        """
        self._unsupported(test, 'skip', reason)

    def addUnexpectedSuccess(self, test, todo):
        """
        Report the unexpected success as a failure.
        """
        self._unsupported(test, 'unexpected success', todo)

    def addExpectedFailure(self, test, error, todo=None):
        """
        Report the expected failure (i.e. todo) as a failure.

        @param test: The test which failed as expected.
        @param error: The failure which was expected.
        @param todo: The todo which expected the failure.  L{TestCase} passes
            it as a third argument, so it has to be accepted here; it is not
            included in the report.
        """
        self._unsupported(test, 'expected failure', error)

    def addSuccess(self, test):
        """
        Forward to the wrapped result.
        """
        self.original.addSuccess(test)

    def upDownError(self, method, error, warn, printStatus):
        """
        Accepted and ignored.
        """
        pass
1363
1364
1365
def suiteVisit(suite, visitor):
    """
    Visit each test in C{suite} with C{visitor}.

    Deprecated in Twisted 8.0; calling this issues a C{DeprecationWarning}.

    @param suite: An object keeping its tests in a C{_tests} attribute.
    @param visitor: A callable which takes a single argument, the L{TestCase}
    instance to visit.
    @return: None
    """
    warnings.warn("Test visitors deprecated in Twisted 8.0",
                  category=DeprecationWarning)
    for member in suite._tests:
        ownVisit = getattr(member, 'visit', None)
        if ownVisit is not None:
            # The member knows how to be visited; let it.
            ownVisit(visitor)
            continue
        if isinstance(member, pyunit.TestCase):
            # Plain pyunit case: adapt it first.
            itrial.ITestCase(member).visit(visitor)
        elif isinstance(member, pyunit.TestSuite):
            # Plain pyunit suite: descend into it.
            suiteVisit(member, visitor)
        else:
            # No usable visit was found above, so this call fails loudly.
            member.visit(visitor)
1389
1390
1391
class TestSuite(pyunit.TestSuite):
    """
    Extend the standard library's C{TestSuite} with support for the visitor
    pattern and a consistently overrideable C{run} method.
    """

    visit = suiteVisit

    def __call__(self, result):
        """
        Calling the suite is the same as calling L{run} on it.
        """
        return self.run(result)


    def run(self, result):
        """
        Call every member of the suite with C{result}, in order, stopping
        early as soon as C{result.shouldStop} is true.

        @return: C{result}
        """
        # we implement this because Python 2.3 unittest defines this code
        # in __call__, whereas 2.4 defines the code in run.
        for member in self._tests:
            if result.shouldStop:
                return result
            member(result)
        return result
1415
1416
1417
class TestDecorator(components.proxyForInterface(itrial.ITestCase,
                                                 "_originalTest")):
    """
    Decorator for test cases.

    @param _originalTest: The wrapped instance of test.
    @type _originalTest: A provider of L{itrial.ITestCase}
    """

    implements(itrial.ITestCase)


    def __call__(self, result):
        """
        Run the unit test; same as L{run}.

        @param result: A TestResult object.
        """
        return self.run(result)


    def run(self, result):
        """
        Run the wrapped test against C{result} wrapped in a
        L{reporter._AdaptedReporter} built with this decorator's class.

        @param result: A TestResult object.
        """
        runOriginal = self._originalTest.run
        adapted = reporter._AdaptedReporter(result, self.__class__)
        return runOriginal(adapted)
1447
1448
1449
1450def _clearSuite(suite):
1451    """
1452    Clear all tests from C{suite}.
1453
1454    This messes with the internals of C{suite}. In particular, it assumes that
1455    the suite keeps all of its tests in a list in an instance variable called
1456    C{_tests}.
1457    """
1458    suite._tests = []
1459
1460
def decorate(test, decorator):
    """
    Decorate all test cases in C{test} with C{decorator}.

    C{test} can be a test case or a test suite. If it is a test suite, then the
    structure of the suite is preserved.

    L{decorate} tries to preserve the class of the test suites it finds, but
    assumes the presence of the C{_tests} attribute on the suite.

    @param test: The C{TestCase} or C{TestSuite} to decorate.

    @param decorator: A unary callable used to decorate C{TestCase}s.

    @return: A decorated C{TestCase} or a C{TestSuite} containing decorated
        C{TestCase}s.
    """
    try:
        members = iter(test)
    except TypeError:
        # Not iterable, so this is a single test case.
        return decorator(test)

    # Iterable, so this is a test suite.  It is emptied and then refilled,
    # through addTest, with the decorated form of what the iterator obtained
    # above yields.
    _clearSuite(test)
    for member in members:
        decorated = decorate(member, decorator)
        test.addTest(decorated)
    return test
1490
1491
1492
class _PyUnitTestCaseAdapter(TestDecorator):
    """
    Adapt from pyunit.TestCase to ITestCase.
    """


    def visit(self, visitor):
        """
        Issue a C{DeprecationWarning}, then call C{visitor} with this adapter.

        Deprecated in Twisted 8.0.
        """
        message = "Test visitors deprecated in Twisted 8.0"
        warnings.warn(message, category=DeprecationWarning)
        visitor(self)
1506
1507
1508
class _BrokenIDTestCaseAdapter(_PyUnitTestCaseAdapter):
    """
    Adapter for pyunit-style C{TestCase} subclasses that have undesirable id()
    methods. That is L{pyunit.FunctionTestCase} and L{pyunit.DocTestCase}.
    """

    def id(self):
        """
        Return the wrapped test's short description, falling back to its own
        C{id()} when it has no short description.
        """
        description = self._originalTest.shortDescription()
        if description is None:
            return self._originalTest.id()
        return description
1523
1524
1525
class _ForceGarbageCollectionDecorator(TestDecorator):
    """
    Forces garbage collection to be run before and after the test. Any errors
    logged during the post-test collection are added to the test result as
    errors.
    """

    def run(self, result):
        """
        Collect garbage, run the wrapped test, then collect garbage again
        with the log observer installed and report what it caught.

        @param result: The result object to run against and report to.
        """
        gc.collect()
        TestDecorator.run(self, result)

        # Only the second collection is observed.
        observer = _logObserver
        observer._add()
        gc.collect()
        for loggedError in observer.getErrors():
            result.addError(self, loggedError)
        observer.flushErrors()
        observer._remove()
1542
1543
# Register how pyunit test cases are adapted to ITestCase: the generic
# adapter for pyunit.TestCase ...
components.registerAdapter(
    _PyUnitTestCaseAdapter, pyunit.TestCase, itrial.ITestCase)


# ... and the id()-fixing adapter for pyunit.FunctionTestCase.
components.registerAdapter(
    _BrokenIDTestCaseAdapter, pyunit.FunctionTestCase, itrial.ITestCase)


# doctest.DocTestCase gets the id()-fixing adapter as well, but only if the
# doctest module actually has that attribute.
_docTestCase = getattr(doctest, 'DocTestCase', None)
if _docTestCase:
    components.registerAdapter(
        _BrokenIDTestCaseAdapter, _docTestCase, itrial.ITestCase)
1556
1557
1558def _iterateTests(testSuiteOrCase):
1559    """
1560    Iterate through all of the test cases in C{testSuiteOrCase}.
1561    """
1562    try:
1563        suite = iter(testSuiteOrCase)
1564    except TypeError:
1565        yield testSuiteOrCase
1566    else:
1567        for test in suite:
1568            for subtest in _iterateTests(test):
1569                yield subtest
1570
1571
1572
# Support for Python 2.3
# Probe whether the standard library's TestSuite can be iterated at all;
# decorate() and _iterateTests() in this module call iter() on suites.
try:
    iter(pyunit.TestSuite())
except TypeError:
    # Python 2.3's TestSuite doesn't support iteration. Let's monkey patch it!
    def __iter__(self):
        # Iterate over the suite's private list of tests.
        return iter(self._tests)
    pyunit.TestSuite.__iter__ = __iter__
1581
1582
1583
class _SubTestCase(TestCase):
    """
    A L{TestCase} constructed with C{'run'} as its method name, so that it
    can be instantiated without naming a real test method.
    """
    def __init__(self):
        super(_SubTestCase, self).__init__('run')

# Shared instance whose assertion methods back the deprecated module-level
# assertion functions created by _deprecate().
_inst = _SubTestCase()
1589
1590def _deprecate(name):
1591    """
1592    Internal method used to deprecate top-level assertions. Do not use this.
1593    """
1594    def _(*args, **kwargs):
1595        warnings.warn("unittest.%s is deprecated.  Instead use the %r "
1596                      "method on unittest.TestCase" % (name, name),
1597                      stacklevel=2, category=DeprecationWarning)
1598        return getattr(_inst, name)(*args, **kwargs)
1599    return _
1600
1601
# Names of the TestCase assertion methods which are also exposed, in
# deprecated form, as functions at module level.
_assertions = ['fail', 'failUnlessEqual', 'failIfEqual', 'failIfEquals',
               'failUnless', 'failUnlessIdentical', 'failUnlessIn',
               'failIfIdentical', 'failIfIn', 'failIf',
               'failUnlessAlmostEqual', 'failIfAlmostEqual',
               'failUnlessRaises', 'assertApproximates',
               'assertFailure', 'failUnlessSubstring', 'failIfSubstring',
               'assertAlmostEqual', 'assertAlmostEquals',
               'assertNotAlmostEqual', 'assertNotAlmostEquals', 'assertEqual',
               'assertEquals', 'assertNotEqual', 'assertNotEquals',
               'assertRaises', 'assert_', 'assertIdentical',
               'assertNotIdentical', 'assertIn', 'assertNotIn',
               'failUnlessFailure', 'assertSubstring', 'assertNotSubstring']


# Bind one warning wrapper per name into the module namespace.
for methodName in _assertions:
    globals()[methodName] = _deprecate(methodName)


# The deprecated assertion functions are deliberately not listed here.
__all__ = ['TestCase', 'FailTest', 'SkipTest']
Note: See TracBrowser for help on using the browser.