root/trunk/twisted/python/test/test_util.py

Revision 29907, 26.2 KB (checked in by mithrandi, 4 weeks ago)

Reopens #4536

The tests don't work on Python 2.4.

Line 
# -*- test-case-name: twisted.python.test.test_util -*-
2# Copyright (c) 2001-2010 Twisted Matrix Laboratories.
3# See LICENSE for details.
4
5import os.path, sys
6import shutil, errno
7try:
8    import pwd, grp
9except ImportError:
10    pwd = grp = None
11
12from twisted.trial import unittest
13
14from twisted.python import util
15from twisted.internet import reactor
16from twisted.internet.interfaces import IReactorProcess
17from twisted.internet.protocol import ProcessProtocol
18from twisted.internet.defer import Deferred
19from twisted.internet.error import ProcessDone
20
21from twisted.test.test_process import MockOS
22
23
24
class UtilTestCase(unittest.TestCase):
    """
    Tests for miscellaneous helpers in L{twisted.python.util}.
    """

    def testUniq(self):
        """
        L{util.uniquify} removes repeated elements, keeping the first
        occurrence of each one in its original position.
        """
        l = ["a", 1, "ab", "a", 3, 4, 1, 2, 2, 4, 6]
        self.assertEquals(util.uniquify(l), ["a", 1, "ab", 3, 4, 2, 6])


    def testRaises(self):
        """
        L{util.raises} returns a true value when the call raises the given
        exception, a false value when the call returns normally, and lets
        any other exception propagate.
        """
        self.failUnless(util.raises(ZeroDivisionError, divmod, 1, 0))
        self.failIf(util.raises(ZeroDivisionError, divmod, 0, 1))

        try:
            util.raises(TypeError, divmod, 1, 0)
        except ZeroDivisionError:
            pass
        else:
            self.fail("util.raises didn't raise when it should have")


    def testUninterruptably(self):
        """
        L{util.untilConcludes} keeps retrying a call which fails with
        C{EINTR} (raised as either L{OSError} or L{IOError}) and returns the
        result of the first call which succeeds.
        """
        def f(a, b):
            self.calls += 1
            # Exceptions are consumed from the end of the list, so the last
            # entry is the one raised first.
            exc = self.exceptions.pop()
            if exc is not None:
                raise exc(errno.EINTR, "Interrupted system call!")
            return a + b

        self.exceptions = [None]
        self.calls = 0
        self.assertEquals(util.untilConcludes(f, 1, 2), 3)
        self.assertEquals(self.calls, 1)

        self.exceptions = [None, OSError, IOError]
        self.calls = 0
        self.assertEquals(util.untilConcludes(f, 2, 3), 5)
        self.assertEquals(self.calls, 3)


    def testNameToLabel(self):
        """
        Test the various kinds of inputs L{nameToLabel} supports.
        """
        nameData = [
            ('f', 'F'),
            ('fo', 'Fo'),
            ('foo', 'Foo'),
            ('fooBar', 'Foo Bar'),
            ('fooBarBaz', 'Foo Bar Baz'),
            ]
        for inp, out in nameData:
            got = util.nameToLabel(inp)
            self.assertEquals(
                got, out,
                "nameToLabel(%r) == %r != %r" % (inp, got, out))


    def test_uidFromNumericString(self):
        """
        When L{uidFromString} is called with a base-ten string representation
        of an integer, it returns the integer.
        """
        self.assertEqual(util.uidFromString("100"), 100)


    def test_uidFromUsernameString(self):
        """
        When L{uidFromString} is called with a user name, it returns the UID
        of the user with that name.
        """
        pwent = pwd.getpwuid(os.getuid())
        self.assertEqual(util.uidFromString(pwent.pw_name), pwent.pw_uid)
    if pwd is None:
        test_uidFromUsernameString.skip = (
            "Username/UID conversion requires the pwd module.")


    def test_gidFromNumericString(self):
        """
        When L{gidFromString} is called with a base-ten string representation
        of an integer, it returns the integer.
        """
        self.assertEqual(util.gidFromString("100"), 100)


    def test_gidFromGroupnameString(self):
        """
        When L{gidFromString} is called with a group name, it returns the GID
        of the group with that name.
        """
        grent = grp.getgrgid(os.getgid())
        self.assertEqual(util.gidFromString(grent.gr_name), grent.gr_gid)
    if grp is None:
        test_gidFromGroupnameString.skip = (
            "Group Name/GID conversion requires the grp module.")


    def test_moduleMovedForSplitDeprecation(self):
        """
        Calling L{moduleMovedForSplit} results in a deprecation warning.
        """
        util.moduleMovedForSplit("foo", "bar", "baz", "quux", "corge", {})
        warnings = self.flushWarnings(
            offendingFunctions=[self.test_moduleMovedForSplitDeprecation])
        # Check the count first so that a missing warning is reported as a
        # test failure instead of an IndexError.
        self.assertEquals(len(warnings), 1)
        self.assertEquals(
            warnings[0]['message'],
            "moduleMovedForSplit is deprecated since Twisted 9.0.")
        self.assertEquals(warnings[0]['category'], DeprecationWarning)
130
131
132
133class TestMergeFunctionMetadata(unittest.TestCase):
134    """
135    Tests for L{mergeFunctionMetadata}.
136    """
137
138    def test_mergedFunctionBehavesLikeMergeTarget(self):
139        """
140        After merging C{foo}'s data into C{bar}, the returned function behaves
141        as if it is C{bar}.
142        """
143        foo_object = object()
144        bar_object = object()
145
146        def foo():
147            return foo_object
148
149        def bar(x, y, (a, b), c=10, *d, **e):
150            return bar_object
151
152        baz = util.mergeFunctionMetadata(foo, bar)
153        self.assertIdentical(baz(1, 2, (3, 4), quux=10), bar_object)
154
155
156    def test_moduleIsMerged(self):
157        """
158        Merging C{foo} into C{bar} returns a function with C{foo}'s
159        C{__module__}.
160        """
161        def foo():
162            pass
163
164        def bar():
165            pass
166        bar.__module__ = 'somewhere.else'
167
168        baz = util.mergeFunctionMetadata(foo, bar)
169        self.assertEqual(baz.__module__, foo.__module__)
170
171
172    def test_docstringIsMerged(self):
173        """
174        Merging C{foo} into C{bar} returns a function with C{foo}'s docstring.
175        """
176
177        def foo():
178            """
179            This is foo.
180            """
181
182        def bar():
183            """
184            This is bar.
185            """
186
187        baz = util.mergeFunctionMetadata(foo, bar)
188        self.assertEqual(baz.__doc__, foo.__doc__)
189
190
191    def test_nameIsMerged(self):
192        """
193        Merging C{foo} into C{bar} returns a function with C{foo}'s name.
194        """
195
196        def foo():
197            pass
198
199        def bar():
200            pass
201
202        baz = util.mergeFunctionMetadata(foo, bar)
203        self.assertEqual(baz.__name__, foo.__name__)
204
205
206    def test_instanceDictionaryIsMerged(self):
207        """
208        Merging C{foo} into C{bar} returns a function with C{bar}'s
209        dictionary, updated by C{foo}'s.
210        """
211
212        def foo():
213            pass
214        foo.a = 1
215        foo.b = 2
216
217        def bar():
218            pass
219        bar.b = 3
220        bar.c = 4
221
222        baz = util.mergeFunctionMetadata(foo, bar)
223        self.assertEqual(foo.a, baz.a)
224        self.assertEqual(foo.b, baz.b)
225        self.assertEqual(bar.c, baz.c)
226
227
228
class OrderedDictTest(unittest.TestCase):
    """
    Tests for L{util.OrderedDict}.
    """

    def testOrderedDict(self):
        """
        Keys come back in insertion order from C{repr}, C{values}, C{keys},
        C{iteritems} and C{popitem}, and deleting a key leaves the order of
        the remaining keys intact.
        """
        od = util.OrderedDict()
        for key, value in [('a', 'b'), ('b', 'a'), (3, 12), (1234, 4321)]:
            od[key] = value
        self.assertEquals(repr(od), "{'a': 'b', 'b': 'a', 3: 12, 1234: 4321}")
        self.assertEquals(od.values(), ['b', 'a', 12, 4321])

        del od[3]
        self.assertEquals(repr(od), "{'a': 'b', 'b': 'a', 1234: 4321}")
        self.assertEquals(od, {'a': 'b', 'b': 'a', 1234: 4321})
        self.assertEquals(od.keys(), ['a', 'b', 1234])
        self.assertEquals(list(od.iteritems()),
                          [('a', 'b'), ('b', 'a'), (1234, 4321)])

        self.assertEquals(od.popitem(), (1234, 4321))


    def testInitialization(self):
        """
        An L{util.OrderedDict} can be built from a mapping or from a sequence
        of pairs; with a sequence, the sequence's order is kept.
        """
        fromMapping = util.OrderedDict({'monkey': 'ook',
                                        'apple': 'red'})
        self.assertTrue(fromMapping._order)

        fromPairs = util.OrderedDict(((1, 1), (3, 3), (2, 2), (0, 0)))
        self.assertEquals(repr(fromPairs), "{1: 1, 3: 3, 2: 2, 0: 0}")
254
class InsensitiveDictTest(unittest.TestCase):
    """
    Tests for L{util.InsensitiveDict}.

    These use the TestCase assertion methods rather than bare C{assert}
    statements, which are removed when Python runs with C{-O}.
    """

    def testPreserve(self):
        """
        With C{preserve=1}, lookups ignore case while C{keys()} and
        C{items()} report each key with the case it was stored under.
        """
        # NOTE(review): bound to a local name presumably so that
        # eval(repr(dct)) below can resolve the class name; confirm against
        # InsensitiveDict.__repr__.
        InsensitiveDict = util.InsensitiveDict
        dct = InsensitiveDict({'Foo': 'bar', 1: 2, 'fnz': {1: 2}}, preserve=1)
        self.assertEquals(dct['fnz'], {1: 2})
        self.assertEquals(dct['foo'], 'bar')
        self.assertEquals(dct.copy(), dct)
        self.assertEquals(dct['foo'], dct.get('Foo'))
        self.assertIn(1, dct)
        self.assertIn('foo', dct)
        self.assertEquals(eval(repr(dct)), dct)
        keys = ['Foo', 'fnz', 1]
        for x in keys:
            self.assertIn(x, dct.keys())
            self.assertIn((x, dct[x]), dct.items())
        self.assertEquals(len(keys), len(dct))
        del dct[1]
        # Deletion is case-insensitive too: 'foo' removes the 'Foo' entry.
        del dct['foo']
        self.assertEquals(len(dct), 1)
        self.assertFalse('foo' in dct)


    def testNoPreserve(self):
        """
        With C{preserve=0}, string keys are reported in lower case.
        """
        InsensitiveDict = util.InsensitiveDict
        dct = InsensitiveDict({'Foo': 'bar', 1: 2, 'fnz': {1: 2}}, preserve=0)
        keys = ['foo', 'fnz', 1]
        for x in keys:
            self.assertIn(x, dct.keys())
            self.assertIn((x, dct[x]), dct.items())
        self.assertEquals(len(keys), len(dct))
        del dct[1]
        del dct['foo']
        self.assertEquals(len(dct), 1)
        self.assertFalse('foo' in dct)
283
284
285
286
class PasswordTestingProcessProtocol(ProcessProtocol):
    """
    Process protocol which writes C{"secret\\n"} to the child as soon as the
    connection is made and records every chunk the child writes as an
    C{(fd, data)} pair.  When the process ends, it fires C{self.finished}
    with C{(reason, output)}.

    C{finished} is not created here; the user must assign a L{Deferred} to
    it before the process ends.
    """
    def connectionMade(self):
        self.output = []
        self.transport.write('secret\n')

    def childDataReceived(self, fd, output):
        chunk = (fd, output)
        self.output.append(chunk)

    def processEnded(self, reason):
        result = (reason, self.output)
        self.finished.callback(result)
301
302
class GetPasswordTest(unittest.TestCase):
    """
    Tests for L{util.getPassword}.
    """
    if not IReactorProcess.providedBy(reactor):
        skip = "Process support required to test getPassword"

    def test_stdin(self):
        """
        Make sure getPassword accepts a password from standard input.  The
        test runs a child process which uses getPassword to read a string
        and then writes that string back out.  It writes a string to the
        child process, reads the child's output, and checks that it is the
        same string.
        """
        p = PasswordTestingProcessProtocol()
        p.finished = Deferred()
        reactor.spawnProcess(
            p,
            sys.executable,
            [sys.executable,
             '-c',
             ('import sys\n'
              'from twisted.python.util import getPassword\n'
              'sys.stdout.write(getPassword())\n'
              'sys.stdout.flush()\n')],
            # Hand this process's import path to the child so that it can
            # import the same twisted.python.util.
            env={'PYTHONPATH': os.pathsep.join(sys.path)})

        def processFinished(result):
            # Unpack here rather than with tuple-parameter syntax, which was
            # removed from the language by PEP 3113.
            reason, output = result
            reason.trap(ProcessDone)
            # fd 1 is the child's stdout.
            self.assertIn((1, 'secret'), output)

        return p.finished.addCallback(processFinished)
332
333
334
class SearchUpwardsTest(unittest.TestCase):
    """
    Tests for L{util.searchupwards}.
    """
    def testSearchupwards(self):
        """
        L{util.searchupwards} returns the ancestor of the start directory
        which contains all of the given files and directories, as an
        absolute path ending in a separator.  It returns C{None} when no such
        directory exists.
        """
        # Use a fresh per-test directory instead of a fixed name in the
        # working directory, so that a tree left behind by an aborted run
        # cannot make os.makedirs fail.
        base = self.mktemp()

        def path(*segments):
            return os.path.join(base, *segments)

        os.makedirs(path('a', 'b', 'c'))
        for segments in [('foo.txt',),
                         ('a', 'foo.txt'),
                         ('a', 'b', 'c', 'foo.txt')]:
            open(path(*segments), 'w').close()
        for segments in [('bar',), ('bam',), ('a', 'bar'), ('a', 'b', 'bam')]:
            os.mkdir(path(*segments))

        # Only the top-level directory has foo.txt, bar and bam together.
        start = path('a', 'b', 'c')
        actual = util.searchupwards(start,
                                    files=['foo.txt'],
                                    dirs=['bar', 'bam'])
        expected = os.path.abspath(base) + os.sep
        self.assertEqual(actual, expected)

        shutil.rmtree(base)
        actual = util.searchupwards(start,
                                    files=['foo.txt'],
                                    dirs=['bar', 'bam'])
        expected = None
        self.assertEqual(actual, expected)
356
class Foo:
    """
    Trivial holder of a single attribute C{x}; the L{DSU} tests sort
    instances of it by that attribute.
    """
    def __init__(self, x):
        self.x = x
360
class DSU(unittest.TestCase):
    """
    Tests for L{util.dsu}.
    """
    def test_dsu(self):
        """
        L{util.dsu} orders its input by the value the key function returns
        for each element.
        """
        items = [Foo(value) for value in range(20, 9, -1)]
        ordered = util.dsu(items, lambda item: item.x)
        self.assertEquals(range(10, 21), [item.x for item in ordered])


    def test_deprecation(self):
        """
        Calling L{util.dsu} emits a L{DeprecationWarning} which points
        callers at the built-in C{sorted}.
        """
        message = ("dsu is deprecated since Twisted 10.1. "
                   "Use the built-in sorted() instead.")
        self.assertWarns(DeprecationWarning, message, __file__,
                         lambda: util.dsu([], lambda: 0))
376
377
378
class IntervalDifferentialTestCase(unittest.TestCase):
    """
    Tests for L{util.IntervalDifferential}.

    Each step of the iterator is compared against a pair; judging from the
    expected data, the pair is presumably (delay until the next event,
    index of the interval that fires, or C{None} for the default).
    """

    def _checkSequence(self, iterator, expected):
        """
        Assert that successive calls to C{iterator.next()} produce the
        elements of C{expected}, in order.
        """
        for step in expected:
            self.assertEquals(iterator.next(), step)


    def testDefault(self):
        d = iter(util.IntervalDifferential([], 10))
        self._checkSequence(d, [(10, None)] * 100)


    def testSingle(self):
        d = iter(util.IntervalDifferential([5], 10))
        self._checkSequence(d, [(5, 0)] * 100)


    def testPair(self):
        d = iter(util.IntervalDifferential([5, 7], 10))
        cycle = [(5, 0), (2, 1), (3, 0), (4, 1), (1, 0), (5, 0),
                 (1, 1), (4, 0), (3, 1), (2, 0), (5, 0), (0, 1)]
        self._checkSequence(d, cycle * 100)


    def testTriple(self):
        d = iter(util.IntervalDifferential([2, 4, 5], 10))
        cycle = [(2, 0), (2, 0), (0, 1), (1, 2), (1, 0), (2, 0), (0, 1),
                 (2, 0), (0, 2), (2, 0), (0, 1), (2, 0), (1, 2), (1, 0),
                 (0, 1), (2, 0), (2, 0), (0, 1), (0, 2)]
        self._checkSequence(d, cycle * 100)


    def testInsert(self):
        d = iter(util.IntervalDifferential([], 10))
        self._checkSequence(d, [(10, None)])
        d.addInterval(3)
        self._checkSequence(d, [(3, 0), (3, 0)])
        d.addInterval(6)
        self._checkSequence(d, [(3, 0), (3, 0), (0, 1)] * 2)


    def testRemove(self):
        d = iter(util.IntervalDifferential([3, 5], 10))
        self._checkSequence(d, [(3, 0), (2, 1), (1, 0)])
        d.removeInterval(3)
        self._checkSequence(d, [(4, 0), (5, 0)])
        d.removeInterval(5)
        self._checkSequence(d, [(10, None)])
        self.assertRaises(ValueError, d.removeInterval, 10)
454
455
456
class Record(util.FancyEqMixin):
    """
    Minimal two-attribute class which relies on L{FancyEqMixin} for
    comparisons; used by the equality tests.
    """
    compareAttributes = ('a', 'b')

    def __init__(self, a, b):
        self.a, self.b = a, b
466
467
468
class DifferentRecord(util.FancyEqMixin):
    """
    Same shape as L{Record} (same compared attributes) but not related to it
    by inheritance; used to check that unrelated classes never compare
    equal.
    """
    compareAttributes = ('a', 'b')

    def __init__(self, a, b):
        self.a, self.b = a, b
478
479
480
class DerivedRecord(Record):
    """
    Subclass of L{Record} which adds nothing; used to check how
    L{FancyEqMixin} treats instances related by inheritance.
    """
485
486
487
class EqualToEverything(object):
    """
    Helper whose instances claim to be equal to any object whatsoever.
    """
    def __eq__(self, other):
        return True


    def __ne__(self, other):
        return not self.__eq__(other)
498
499
500
class EqualToNothing(object):
    """
    Helper whose instances claim to be unequal to every object, even
    themselves.
    """
    def __eq__(self, other):
        return False


    def __ne__(self, other):
        return not self.__eq__(other)
511
512
513
class EqualityTests(unittest.TestCase):
    """
    Tests for L{FancyEqMixin}.
    """
    def test_identity(self):
        """
        A L{FancyEqMixin} subclass which names no comparison attributes
        falls back to identity: an instance is equal only to itself.
        """
        class Empty(util.FancyEqMixin):
            pass

        self.assertFalse(Empty() == Empty())
        self.assertTrue(Empty() != Empty())
        same = Empty()
        self.assertTrue(same == same)
        self.assertFalse(same != same)


    def test_equality(self):
        """
        Two L{Record}s compare equal exactly when every compared attribute
        is equal; a difference in any attribute makes them unequal.
        """
        self.assertTrue(Record(1, 2) == Record(1, 2))
        for other in (Record(1, 3), Record(2, 2), Record(3, 4)):
            self.assertFalse(Record(1, 2) == other)


    def test_unequality(self):
        """
        C{!=} between two L{Record}s is the negation of C{==}.
        """
        self.assertFalse(Record(1, 2) != Record(1, 2))
        for other in (Record(1, 3), Record(2, 2), Record(3, 4)):
            self.assertTrue(Record(1, 2) != other)


    def test_differentClassesEquality(self):
        """
        Instances of two unrelated L{FancyEqMixin} classes never compare
        equal, even when their attributes match.
        """
        self.assertFalse(Record(1, 2) == DifferentRecord(1, 2))


    def test_differentClassesInequality(self):
        """
        Instances of two unrelated L{FancyEqMixin} classes always compare
        unequal, even when their attributes match.
        """
        self.assertTrue(Record(1, 2) != DifferentRecord(1, 2))


    def test_inheritedClassesEquality(self):
        """
        An instance of a subclass compares equal to an instance of its
        L{FancyEqMixin}-based base class exactly when all compared
        attributes are equal.
        """
        self.assertTrue(Record(1, 2) == DerivedRecord(1, 2))
        for other in (DerivedRecord(1, 3), DerivedRecord(2, 2),
                      DerivedRecord(3, 4)):
            self.assertFalse(Record(1, 2) == other)


    def test_inheritedClassesInequality(self):
        """
        An instance of a subclass compares unequal to an instance of its
        L{FancyEqMixin}-based base class as soon as any compared attribute
        differs.
        """
        self.assertFalse(Record(1, 2) != DerivedRecord(1, 2))
        for other in (DerivedRecord(1, 3), DerivedRecord(2, 2),
                      DerivedRecord(3, 4)):
            self.assertTrue(Record(1, 2) != other)


    def test_rightHandArgumentImplementsEquality(self):
        """
        When the right-hand operand of C{==} is of a type unrelated to the
        L{FancyEqMixin} instance on the left, the right-hand operand gets to
        decide the result.
        """
        self.assertTrue(Record(1, 2) == EqualToEverything())
        self.assertFalse(Record(1, 2) == EqualToNothing())


    def test_rightHandArgumentImplementsUnequality(self):
        """
        When the right-hand operand of C{!=} is of a type unrelated to the
        L{FancyEqMixin} instance on the left, the right-hand operand gets to
        decide the result.
        """
        self.assertFalse(Record(1, 2) != EqualToEverything())
        self.assertTrue(Record(1, 2) != EqualToNothing())
616
617
618
class RunAsEffectiveUserTests(unittest.TestCase):
    """
    Tests for the L{util.runAsEffectiveUser} function.
    """

    if getattr(os, "geteuid", None) is None:
        skip = "geteuid/seteuid not available"

    def setUp(self):
        # Route effective UID/GID queries and changes through a MockOS, so
        # the tests never touch the real process credentials and can inspect
        # the calls that were made.
        self.mockos = MockOS()
        self.patch(os, "geteuid", self.mockos.geteuid)
        self.patch(os, "getegid", self.mockos.getegid)
        self.patch(os, "seteuid", self.mockos.seteuid)
        self.patch(os, "setegid", self.mockos.setegid)


    def _securedFunction(self, startUID, startGID, wantUID, wantGID):
        """
        Function run under L{util.runAsEffectiveUser}.  It checks that the
        wanted UID and GID are either the starting values or the values most
        recently passed to C{os.seteuid} and C{os.setegid}.
        """
        self.assertTrue(wantUID == startUID or
                        wantUID == self.mockos.seteuidCalls[-1])
        self.assertTrue(wantGID == startGID or
                        wantGID == self.mockos.setegidCalls[-1])


    def test_forwardResult(self):
        """
        L{util.runAsEffectiveUser} returns the result obtained by calling
        the given function.
        """
        result = util.runAsEffectiveUser(0, 0, lambda: 1)
        self.assertEquals(result, 1)


    def test_takeParameters(self):
        """
        L{util.runAsEffectiveUser} passes the given positional parameters to
        the given function.
        """
        result = util.runAsEffectiveUser(0, 0, lambda x: 2*x, 3)
        self.assertEquals(result, 6)


    def test_takesKeyworkArguments(self):
        """
        L{util.runAsEffectiveUser} passes the given keyword parameters to the
        given function.
        """
        result = util.runAsEffectiveUser(0, 0, lambda x, y=1, z=1: x*y*z, 2, z=3)
        self.assertEquals(result, 6)


    def _testUIDGIDSwitch(self, startUID, startGID, wantUID, wantGID,
                          expectedUIDSwitches, expectedGIDSwitches):
        """
        Helper checking the calls made to C{os.seteuid} and C{os.setegid} by
        L{util.runAsEffectiveUser} when switching from C{startUID} to
        C{wantUID} and from C{startGID} to C{wantGID}.

        C{expectedUIDSwitches} and C{expectedGIDSwitches} are the exact
        lists of values the mock is expected to record, including any switch
        back to the starting IDs.
        """
        self.mockos.euid = startUID
        self.mockos.egid = startGID
        util.runAsEffectiveUser(
            wantUID, wantGID,
            self._securedFunction, startUID, startGID, wantUID, wantGID)
        self.assertEquals(self.mockos.seteuidCalls, expectedUIDSwitches)
        self.assertEquals(self.mockos.setegidCalls, expectedGIDSwitches)
        # Reset the recorded calls so that the next scenario in the same
        # test starts with empty call logs.
        self.mockos.seteuidCalls = []
        self.mockos.setegidCalls = []


    def test_root(self):
        """
        Check UID/GID switches when current effective UID is root.
        """
        self._testUIDGIDSwitch(0, 0, 0, 0, [], [])
        self._testUIDGIDSwitch(0, 0, 1, 0, [1, 0], [])
        self._testUIDGIDSwitch(0, 0, 0, 1, [], [1, 0])
        self._testUIDGIDSwitch(0, 0, 1, 1, [1, 0], [1, 0])


    def test_UID(self):
        """
        Check UID/GID switches when current effective UID is non-root.
        """
        self._testUIDGIDSwitch(1, 0, 0, 0, [0, 1], [])
        self._testUIDGIDSwitch(1, 0, 1, 0, [], [])
        self._testUIDGIDSwitch(1, 0, 1, 1, [0, 1, 0, 1], [1, 0])
        self._testUIDGIDSwitch(1, 0, 2, 1, [0, 2, 0, 1], [1, 0])


    def test_GID(self):
        """
        Check UID/GID switches when current effective GID is non-root.
        """
        self._testUIDGIDSwitch(0, 1, 0, 0, [], [0, 1])
        self._testUIDGIDSwitch(0, 1, 0, 1, [], [])
        self._testUIDGIDSwitch(0, 1, 1, 1, [1, 0], [])
        self._testUIDGIDSwitch(0, 1, 1, 2, [1, 0], [2, 1])


    def test_UIDGID(self):
        """
        Check UID/GID switches when current effective UID/GID is non-root.
        """
        self._testUIDGIDSwitch(1, 1, 0, 0, [0, 1], [0, 1])
        self._testUIDGIDSwitch(1, 1, 0, 1, [0, 1], [])
        self._testUIDGIDSwitch(1, 1, 1, 0, [0, 1, 0, 1], [0, 1])
        self._testUIDGIDSwitch(1, 1, 1, 1, [], [])
        self._testUIDGIDSwitch(1, 1, 2, 1, [0, 2, 0, 1], [])
        self._testUIDGIDSwitch(1, 1, 1, 2, [0, 1, 0, 1], [2, 1])
        self._testUIDGIDSwitch(1, 1, 2, 2, [0, 2, 0, 1], [2, 1])
731
732
733
class UnsignedIDTests(unittest.TestCase):
    """
    Tests for L{util.unsignedID} and L{util.setIDFunction}.
    """
    def setUp(self):
        """
        Remember the current L{util._idFunction} and put it back once the
        test is over, since the tests below install their own.
        """
        original = util._idFunction
        self.addCleanup(setattr, util, '_idFunction', original)


    def test_setIDFunction(self):
        """
        L{util.setIDFunction} hands back whatever was installed before the
        call.
        """
        sentinel = object()
        old = util.setIDFunction(sentinel)
        self.assertIdentical(sentinel, util.setIDFunction(old))


    def test_unsignedID(self):
        """
        L{util.unsignedID} asks the function installed with
        L{util.setIDFunction} for an object's integer id and shifts negative
        results into the positive range.
        """
        foo = object()
        bar = object()

        # Fake identities: one positive, one negative.
        fakeIdentities = {foo: 17, bar: -73}
        util.setIDFunction(fakeIdentities.__getitem__)

        self.assertEquals(util.unsignedID(foo), 17)
        # A negative id is offset by (sys.maxint + 1) * 2.
        self.assertEquals(util.unsignedID(bar), (sys.maxint + 1) * 2 - 73)


    def test_defaultIDFunction(self):
        """
        Without any replacement installed, L{util.unsignedID} is based on the
        built-in L{id}.
        """
        obj = object()
        expected = id(obj)
        if expected < 0:
            expected = expected + (sys.maxint + 1) * 2

        self.assertEquals(util.unsignedID(obj), expected)
786
787
788
class InitGroupsTests(unittest.TestCase):
    """
    Tests for L{util.initgroups}.
    """

    if pwd is None:
        skip = "pwd not available"


    def setUp(self):
        # The tests replace these module attributes; restore the real ones
        # afterwards.
        self.addCleanup(setattr, util, "_c_initgroups", util._c_initgroups)
        self.addCleanup(setattr, util, "setgroups", util.setgroups)


    def test_initgroupsForceC(self):
        """
        If we fake the presence of the C extension, it's called instead of the
        Python implementation, and C{setgroups} is never called.
        """
        calls = []
        util._c_initgroups = lambda x, y: calls.append((x, y))
        setgroupsCalls = []
        # Record setgroups calls in their own list, so the final assertion can
        # actually detect a fallback to the Python implementation.
        util.setgroups = setgroupsCalls.append

        util.initgroups(os.getuid(), 4)
        # The C version is handed the user name, not the UID.
        self.assertEquals(calls, [(pwd.getpwuid(os.getuid())[0], 4)])
        self.assertFalse(setgroupsCalls)


    def test_initgroupsForcePython(self):
        """
        If we fake the absence of the C extension, the Python implementation is
        called instead, calling C{os.setgroups}.
        """
        util._c_initgroups = None
        calls = []
        util.setgroups = calls.append
        util.initgroups(os.getuid(), os.getgid())
        # Something should be in the calls, we don't really care what
        self.assertTrue(calls)


    def test_initgroupsInC(self):
        """
        If the C extension is present, it's called instead of the Python
        version.  We check that by making sure C{os.setgroups} is not called.
        """
        calls = []
        util.setgroups = calls.append
        try:
            util.initgroups(os.getuid(), os.getgid())
        except OSError:
            # NOTE(review): presumably the real C initgroups may fail for an
            # unprivileged user; only whether setgroups was reached matters.
            pass
        self.assertFalse(calls)


    if util._c_initgroups is None:
        test_initgroupsInC.skip = "C initgroups not available"
Note: See TracBrowser for help on using the browser.