Ticket #6362: test_overrideserver.py

File test_overrideserver.py, 7.7 KB (added by rwall, 11 months ago)

A test module that I wrote for the twisted.names documentation in ticket #6864.

Line 
1# Copyright (c) Twisted Matrix Laboratories.
2# See LICENSE for details.
3
4"""
5Tests for docs/names/howto/listings/auth_override.py
6"""
7
8from override_server import DynamicResolver
9
10from twisted.internet import defer, reactor
11from twisted.names import dns, error, server, client
12from twisted.trial.unittest import SynchronousTestCase, TestCase, FailTest
13
14
15
class RaisedArguments(Exception):
    """
    An exception which remembers the positional and keyword arguments handed
    to it, so that a test can inspect how a patched callable was invoked.

    @ivar args: The positional arguments, as a L{tuple}.
    @ivar kwargs: The keyword arguments, as a L{dict}.
    """
    def __init__(self, args, kwargs):
        # Let Exception populate C{self.args} (it stores a tuple), then keep
        # the keyword arguments alongside it.
        super(RaisedArguments, self).__init__(*args)
        self.kwargs = kwargs
23
24
25
class DNSAssertionsMixin(object):
    """
    A custom assertion and helpers for comparing the results of
    IResolver.lookup methods.

    Compares the RRHeaders and Record payloads separately.

    Classes using this mixin must also provide a C{fail} method, for example
    by inheriting from a trial C{TestCase}.
    """
    def _justPayloads(self, headers):
        """
        Return only the payloads from a list of headers.

        @param headers: An iterable of objects with a C{payload} attribute.
        @return: A L{list} of those payloads, in the same order.
        """
        return [h.payload for h in headers]


    def _allPayloads(self, sections):
        """
        Return all the payloads from the three section lists typically
        returned from IResolver.lookup methods.

        @param sections: An iterable of header lists.
        @return: A L{list} containing one list of payloads per section.
        """
        return [self._justPayloads(section) for section in sections]


    def assertEqualResolverResponse(self, expected, actual):
        """
        Compare the headers and payloads from the section lists returned by
        IResolver.lookup methods.

        Failures are accompanied by a print out of the headers and payloads.

        @param expected: The expected sequence of section lists.
        @param actual: The actual sequence of section lists.
        @raise: Whatever C{self.fail} raises, if C{expected} and C{actual}
            are not equal.
        """
        # Compare directly instead of catching the exception raised by
        # assertEqual.  This does not depend on a particular failure class,
        # and on Python 3 it avoids chaining the original failure onto the
        # more detailed one below.
        if not expected == actual:
            self.fail(
                'Header / Payload mismatch:\n\n'
                'Headers: \n%r\n%r\n'
                'Payloads: \n%r\n%r\n' % (expected,
                                          actual,
                                          self._allPayloads(expected),
                                          self._allPayloads(actual))
            )
69
70
71
class Raiser(object):
    """
    A stand-in for a method under test.  Calling L{call} raises an exception
    that records the arguments it received, so a test can verify the call
    signature of whatever invoked it.
    """
    def __init__(self, exception):
        """
        @param exception: A factory, typically an exception class, which is
            called with C{(args, kwargs)} to build the exception to raise.
        """
        self._exception = exception


    def call(self, *args, **kwargs):
        """
        Build an exception from the received positional tuple and keyword
        dict, and raise it.
        """
        factory = self._exception
        raised = factory(args, kwargs)
        raise raised
84
85
class DynamicResolverTests(SynchronousTestCase, DNSAssertionsMixin):
    """
    Unit tests for L{DynamicResolver}.
    """
    def _assertWorkstationResponse(self, name, address):
        """
        Assert that C{_doDynamicResponse} answers a query for C{name} with a
        single A record for C{address} (TTL 0) and no records in the other
        two sections.

        @param name: The query name, as L{bytes}.
        @param address: The expected dotted-quad IPv4 address.
        """
        r = DynamicResolver()
        self.assertEqualResolverResponse(
            ([dns.RRHeader(name=name,
                           payload=dns.Record_A(address=address, ttl=0))],
             [], []),
            r._doDynamicResponse(dns.Query(name))
        )


    def test_queryCallsDynamicResponseRequired(self):
        """
        query calls _dynamicResponseRequired with the supplied query to
        determine whether the answer should be calculated dynamically.
        """
        r = DynamicResolver()

        class ExpectedException(RaisedArguments):
            pass

        # The fake raises an exception that records its call arguments, so
        # the query never gets past this point.
        r._dynamicResponseRequired = Raiser(ExpectedException).call

        dummyQuery = object()

        e = self.assertRaises(ExpectedException, r.query, dummyQuery)
        self.assertEqual(
            ((dummyQuery,), {}),
            (e.args, e.kwargs)
        )


    def test_dynamicResponseRequiredType(self):
        """
        DynamicResolver._dynamicResponseRequired returns True if
        query.type == A else False.
        """
        r = DynamicResolver()
        self.assertEqual(
            (True, False),
            (r._dynamicResponseRequired(
                dns.Query(name=b'workstation1.example.com', type=dns.A)),
             r._dynamicResponseRequired(
                dns.Query(name=b'workstation1.example.com', type=dns.SOA)))
        )


    def test_dynamicResponseRequiredName(self):
        """
        DynamicResolver._dynamicResponseRequired returns True if query.name
        begins with the word workstation, else False.
        """
        r = DynamicResolver()
        self.assertEqual(
            (True, False),
            (r._dynamicResponseRequired(
                dns.Query(name=b'workstation1.example.com', type=dns.A)),
             r._dynamicResponseRequired(
                dns.Query(name=b'foo1.example.com', type=dns.A)))
        )


    def test_queryCallsDoDynamicResponse(self):
        """
        DynamicResolver.query will call _doDynamicResponse to calculate the
        response to a dynamic query.
        """
        r = DynamicResolver()

        # Force the dynamic path regardless of the query contents.
        r._dynamicResponseRequired = lambda query: True

        class ExpectedException(RaisedArguments):
            pass

        r._doDynamicResponse = Raiser(ExpectedException).call

        dummyQuery = object()
        e = self.assertRaises(
            ExpectedException,
            r.query, dummyQuery
        )
        self.assertEqual(
            ((dummyQuery,), {}),
            (e.args, e.kwargs)
        )


    def test_doDynamicResponseWorkstation1(self):
        """
        _doDynamicResponse takes the trailing integer in the first label of
        the query name and uses it as the last octet of the returned IP
        address.
        """
        self._assertWorkstationResponse(
            b'workstation1.example.com', '172.0.2.1')


    def test_doDynamicResponseWorkstation2(self):
        """
        _doDynamicResponse answers workstation2.example.com with 172.0.2.2,
        confirming that the last octet follows the trailing integer of the
        first label rather than being fixed.
        """
        self._assertWorkstationResponse(
            b'workstation2.example.com', '172.0.2.2')


    def test_querySuccess(self):
        """
        query returns a deferred success wrapping the results lists from
        _doDynamicResponse.
        """
        r = DynamicResolver()
        r._dynamicResponseRequired = lambda query: True
        dummyResponse = object()
        r._doDynamicResponse = lambda query: dummyResponse
        res = self.successResultOf(r.query(dns.Query()))
        self.assertIs(dummyResponse, res)


    def test_queryDomainError(self):
        """
        query returns a deferred failure wrapping DomainError if the query is
        not to be handled dynamically.
        """
        r = DynamicResolver()
        r._dynamicResponseRequired = lambda query: False
        d = r.query(dns.Query(b'foo.example.com'))
        self.failureResultOf(d, error.DomainError)
204
205
class RoundTripTests(TestCase, DNSAssertionsMixin):
    """
    Functional tests which set up a listening server and send it requests
    using a network client.
    """
    def buildClientServer(self, fallbackResolver=None):
        """
        Start a DNS server on an ephemeral loopback UDP port and return a
        client resolver that sends its queries to that server.

        The server's authorities are a L{DynamicResolver}, followed by
        C{fallbackResolver} when one is given.  The listening port is
        stopped during test cleanup.

        @param fallbackResolver: An optional resolver appended after the
            L{DynamicResolver} in the server's authorities list.
        @return: A L{client.Resolver} pointed at the listening server.
        """
        resolvers = [DynamicResolver()]
        if fallbackResolver is not None:
            resolvers.append(fallbackResolver)
        s = server.DNSServerFactory(
            authorities=resolvers,
        )

        # Bind to loopback only.  The client below queries 127.0.0.1, so
        # there is no reason to expose the test server on other interfaces.
        listeningPort = reactor.listenUDP(
            0, dns.DNSDatagramProtocol(controller=s), interface='127.0.0.1')
        self.addCleanup(listeningPort.stopListening)
        return client.Resolver(
            servers=[('127.0.0.1', listeningPort.getHost().port)])


    def test_query(self):
        """
        An A lookup for a workstation name, sent over the network, is
        answered with the dynamically calculated address record.
        """
        hostname = b'workstation1.example.com'

        expected = (
            [dns.RRHeader(name=hostname,
                          payload=dns.Record_A('172.0.2.1', ttl=0))],
            [],
            []
        )

        r = self.buildClientServer()

        return r.lookupAddress(hostname).addCallback(
            self.assertEqualResolverResponse,
            expected
        )


    def test_queryFallback(self):
        """
        A TXT lookup, which DynamicResolver does not answer dynamically, is
        answered by the fallback resolver given to the server.
        """
        hostname = b'workstation1.example.com'

        # TXT data must be bytes so that it can be encoded on the wire and
        # compares equal to the decoded record on Python 3.
        expectedAnswer = dns.RRHeader(
            name=hostname, type=dns.TXT,
            payload=dns.Record_TXT(b'Foo', ttl=0))

        class FakeFallbackResolver(object):
            def query(self, query, timeout=None):
                return defer.succeed(([expectedAnswer], [], []))
        fbr = FakeFallbackResolver()
        r = self.buildClientServer(fallbackResolver=fbr)

        expected = (
            [expectedAnswer],
            [],
            []
        )

        return r.lookupText(hostname).addCallback(
            self.assertEqualResolverResponse,
            expected
        )