root/tags/releases/twisted-8.2.0/twisted/web/client.py

Revision 24817, 15.9 KB (checked in by exarkun, 2 years ago)

Merge http-client-duplicate-headers-1382

Author: exarkun
Reviewer: therve
Fixes: #1382

Fix HTTPPageGetter so that it does not incorrectly
duplicate certain headers when they are included in
the request header dictionary passed to the factory
for that protocol.

Line 
1# -*- test-case-name: twisted.web.test.test_webclient -*-
2# Copyright (c) 2001-2008 Twisted Matrix Laboratories.
3# See LICENSE for details.
4
5"""
6HTTP client.
7"""
8
9import os, types
10from urlparse import urlunparse
11
12from twisted.web import http
13from twisted.internet import defer, protocol, reactor
14from twisted.python import failure
15from twisted.python.util import InsensitiveDict
16from twisted.web import error
17from twisted.python.compat import set
18
19
class PartialDownloadError(error.Error):
    """
    Error reported (wrapped in a Failure) when the connection ended before
    the complete response body had been received.

    @ivar response: Whatever part of the response body did arrive before
        the connection was lost.
    """
26
27
class HTTPPageGetter(http.HTTPClient):
    """
    Protocol which issues the single HTTP request described by its factory.

    The request method, path, headers, cookies and optional body are read
    from C{self.factory}; the response status, headers and body (or an
    error) are handed back through the factory's C{gotStatus},
    C{gotHeaders}, C{page} and C{noPage} methods.  Only one request is made
    per connection.

    @ivar quietLoss: When true, losing the connection is not reported to the
        factory, because the outcome has already been delivered (after a
        redirect, the redirect limit being hit, or a timeout).
    @ivar followRedirect: When true, 301/302/303 responses reconnect the
        factory to the new location; otherwise they are reported as
        L{error.PageRedirect}.
    @ivar failed: Set when the status has no dedicated handler; the response
        is then reported as an L{error.Error}.
    """

    quietLoss = 0
    followRedirect = 1
    failed = 0

    # Request headers which connectionMade generates itself.  Entries with
    # these names in the factory's header dictionary are not copied into the
    # request verbatim, so they are never sent twice.
    _specialHeaders = set(('host', 'user-agent', 'cookie', 'content-length'))

    # Port implied by each scheme when the URL does not name one.
    _defaultPorts = {'http': 80, 'https': 443}

    def _defaultHostHeader(self):
        """
        Return the Host header value used when the factory's headers do not
        provide one.

        RFC 2616 section 14.23 requires the port to be part of the Host
        header whenever it is not the default port of the scheme.
        """
        host = self.factory.host
        defaultPort = self._defaultPorts.get(
            getattr(self.factory, 'scheme', None))
        port = getattr(self.factory, 'port', None)
        if defaultPort is not None and port is not None and port != defaultPort:
            host = '%s:%s' % (host, port)
        return host

    def connectionMade(self):
        """
        Send the request line, headers, cookies and optional body.
        """
        method = getattr(self.factory, 'method', 'GET')
        self.sendCommand(method, self.factory.path)
        self.sendHeader('Host', self.factory.headers.get(
            "host", self._defaultHostHeader()))
        self.sendHeader('User-Agent', self.factory.agent)
        data = getattr(self.factory, 'postdata', None)
        if data is not None:
            self.sendHeader("Content-Length", str(len(data)))

        cookieData = []
        for (key, value) in self.factory.headers.items():
            if key.lower() not in self._specialHeaders:
                # Special headers are produced by this method instead.
                self.sendHeader(key, value)
            if key.lower() == 'cookie':
                # An explicit Cookie header is merged with the factory's
                # cookies into a single header below.
                cookieData.append(value)
        for cookie, cookval in self.factory.cookies.items():
            cookieData.append('%s=%s' % (cookie, cookval))
        if cookieData:
            self.sendHeader('Cookie', '; '.join(cookieData))
        self.endHeaders()
        # Response headers are accumulated here by handleHeader.
        self.headers = {}

        if data is not None:
            self.transport.write(data)

    def handleHeader(self, key, value):
        """
        Record a response header under its lower-cased name; repeated
        headers accumulate in a list.
        """
        key = key.lower()
        l = self.headers[key] = self.headers.get(key, [])
        l.append(value)

    def handleStatus(self, version, status, message):
        """
        Remember the status line and pass it on to the factory.
        """
        self.version, self.status, self.message = version, status, message
        self.factory.gotStatus(version, status, message)

    def handleEndHeaders(self):
        """
        Hand the headers to the factory, then dispatch on the status code to
        C{handleStatus_<code>}, or L{handleStatusDefault} if there is none.
        """
        self.factory.gotHeaders(self.headers)
        m = getattr(self, 'handleStatus_'+self.status, self.handleStatusDefault)
        m()

    def handleStatus_200(self):
        pass

    # Lambdas (rather than direct aliases) so that a subclass overriding
    # handleStatus_200 is also used for 201 and 202.
    handleStatus_201 = lambda self: self.handleStatus_200()
    handleStatus_202 = lambda self: self.handleStatus_200()

    def handleStatusDefault(self):
        """
        Mark the response as failed; handleResponse reports it as an error.
        """
        self.failed = 1

    def handleStatus_301(self):
        """
        Follow, or report, a redirect to the URL in the Location header.

        Without a Location header the response is treated as a failure.
        """
        l = self.headers.get('location')
        if not l:
            self.handleStatusDefault()
            return
        url = l[0]
        if self.followRedirect:
            self.factory._redirectCount += 1
            if self.factory._redirectCount >= self.factory.redirectLimit:
                err = error.InfiniteRedirection(
                    self.status,
                    'Infinite redirection detected',
                    location=url)
                self.factory.noPage(failure.Failure(err))
                self.quietLoss = True
                self.transport.loseConnection()
                return

            self.factory.setURL(url)

            if self.factory.scheme == 'https':
                from twisted.internet import ssl
                contextFactory = ssl.ClientContextFactory()
                reactor.connectSSL(self.factory.host, self.factory.port,
                                   self.factory, contextFactory)
            else:
                reactor.connectTCP(self.factory.host, self.factory.port,
                                   self.factory)
        else:
            self.handleStatusDefault()
            self.factory.noPage(
                failure.Failure(
                    error.PageRedirect(
                        self.status, self.message, location = url)))
        # The outcome now comes from the new connection (or was already
        # reported above), so dropping this one is not an error.
        self.quietLoss = True
        self.transport.loseConnection()

    handleStatus_302 = lambda self: self.handleStatus_301()

    def handleStatus_303(self):
        """
        Handle 303 See Other: the new location is retrieved with GET.

        The original request body is dropped as well, since the follow-up
        GET must not carry it (RFC 2616 section 10.3.4).
        """
        self.factory.method = 'GET'
        self.factory.postdata = None
        self.handleStatus_301()

    def connectionLost(self, reason):
        """
        Unless the outcome was already delivered, finish the response and
        report the loss to the factory (a no-op for the factory once it has
        a result).
        """
        if not self.quietLoss:
            http.HTTPClient.connectionLost(self, reason)
            self.factory.noPage(reason)

    def handleResponse(self, response):
        """
        Deliver the response body, or exactly one error, to the factory and
        close the connection.
        """
        if self.quietLoss:
            return
        if self.failed:
            self.factory.noPage(
                failure.Failure(
                    error.Error(
                        self.status, self.message, response)))
        elif self.factory.method.upper() == 'HEAD':
            # Callback with empty string, since there is never a response
            # body for HEAD requests.
            self.factory.page('')
        elif self.length is not None and self.length != 0:
            # self.length is maintained by http.HTTPClient; a remaining
            # non-zero value presumably means the body was cut short.
            self.factory.noPage(failure.Failure(
                PartialDownloadError(self.status, self.message, response)))
        else:
            self.factory.page(response)
        # server might be stupid and not close connection. admittedly
        # the fact we do only one request per connection is also
        # stupid...
        self.transport.loseConnection()

    def timeout(self):
        """
        Abort the request and errback the factory with a TimeoutError.
        """
        self.quietLoss = True
        self.transport.loseConnection()
        self.factory.noPage(defer.TimeoutError("Getting %s took longer than %s seconds." % (self.factory.url, self.factory.timeout)))
162
163
class HTTPPageDownloader(HTTPPageGetter):
    """
    Variant of L{HTTPPageGetter} which forwards each chunk of the response
    body to its factory as it arrives.

    On a successful status (200, the inherited 201/202 aliases, or 206) the
    factory's C{pageStart(partialContent)} is called.  Every body chunk then
    goes to C{pagePart(data)}, and C{pageEnd()} is called when the response
    ends.  A failed status is reported through C{noPage}.
    """

    # Set between a successful status and the end of the response; body
    # data reaches the factory only while it is set.
    transmittingPage = 0

    def handleStatus_200(self, partialContent=0):
        HTTPPageGetter.handleStatus_200(self)
        self.transmittingPage = 1
        self.factory.pageStart(partialContent)

    def handleStatus_206(self):
        # 206 Partial Content: handled like 200, but the factory is told it
        # is receiving only part of the resource.
        self.handleStatus_200(partialContent=1)

    def handleResponsePart(self, data):
        if not self.transmittingPage:
            return
        self.factory.pagePart(data)

    def handleResponseEnd(self):
        if self.transmittingPage:
            self.factory.pageEnd()
            self.transmittingPage = 0
        if not self.failed:
            return
        err = error.Error(self.status, self.message, None)
        self.factory.noPage(failure.Failure(err))
        self.transport.loseConnection()
190
191
class HTTPClientFactory(protocol.ClientFactory):
    """Download a given URL.

    @type deferred: Deferred
    @ivar deferred: A Deferred that will fire when the content has
          been retrieved. Once this is fired, the ivars `status', `version',
          and `message' will be set.

    @type status: str
    @ivar status: The status of the response.

    @type version: str
    @ivar version: The version of the response.

    @type message: str
    @ivar message: The text message returned with the status.

    @type response_headers: dict
    @ivar response_headers: The headers that were specified in the
          response from the server.

    @type method: str
    @ivar method: The HTTP method to use in the request.  This should be one of
        OPTIONS, GET, HEAD, POST, PUT, DELETE, TRACE, or CONNECT (case
        matters).  Other values may be specified if the server being contacted
        supports them.

    @type cookies: dict
    @ivar cookies: Cookie names mapped to values.  Updated from the
          C{Set-Cookie} headers of each response.

    @type followRedirect: bool
    @ivar followRedirect: Whether protocols built by this factory follow
          redirects.  Copied onto each protocol instance by L{buildProtocol}.

    @type redirectLimit: int
    @ivar redirectLimit: The maximum number of HTTP redirects that can occur
          before it is assumed that the redirection is endless.

    @type _redirectCount: int
    @ivar _redirectCount: The current number of HTTP redirects encountered.
    """

    protocol = HTTPPageGetter

    url = None
    scheme = None
    host = ''
    port = None
    path = None
    followRedirect = 1

    def __init__(self, url, method='GET', postdata=None, headers=None,
                 agent="Twisted PageGetter", timeout=0, cookies=None,
                 followRedirect=1, redirectLimit=20):
        # Stored on this factory and copied to each protocol instance in
        # buildProtocol.  Assigning it to self.protocol would modify the
        # protocol *class*, changing the redirect policy of every other
        # factory that uses the same class.
        self.followRedirect = followRedirect
        self.redirectLimit = redirectLimit
        self._redirectCount = 0
        self.timeout = timeout
        self.agent = agent

        if cookies is None:
            cookies = {}
        self.cookies = cookies
        if headers is not None:
            self.headers = InsensitiveDict(headers)
        else:
            self.headers = InsensitiveDict()
        if postdata is not None:
            self.headers.setdefault('Content-Length', len(postdata))
            # just in case a broken http/1.1 decides to keep connection alive
            self.headers.setdefault("connection", "close")
        self.postdata = postdata
        self.method = method

        self.setURL(url)

        # Cleared once a result has been delivered through self.deferred;
        # later page()/noPage() calls are then ignored.
        self.waiting = 1
        self.deferred = defer.Deferred()
        self.response_headers = None

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self.url)

    def setURL(self, url):
        """
        Point the factory at C{url}.

        The scheme, host and port are only replaced when C{url} contains both
        a scheme and a host; otherwise just the path changes.
        """
        self.url = url
        scheme, host, port, path = _parse(url)
        if scheme and host:
            self.scheme = scheme
            self.host = host
            self.port = port
        self.path = path

    def buildProtocol(self, addr):
        """
        Build the protocol, give it this factory's redirect policy, and
        schedule its timeout if one was requested.
        """
        p = protocol.ClientFactory.buildProtocol(self, addr)
        p.followRedirect = self.followRedirect
        if self.timeout:
            timeoutCall = reactor.callLater(self.timeout, p.timeout)
            self.deferred.addBoth(self._cancelTimeout, timeoutCall)
        return p

    def _cancelTimeout(self, result, timeoutCall):
        if timeoutCall.active():
            timeoutCall.cancel()
        return result

    def gotHeaders(self, headers):
        """
        Remember the response headers and absorb any C{Set-Cookie} values
        into C{self.cookies}.

        Only the leading C{NAME=VALUE} pair of each cookie is kept; cookie
        attributes after the first C{;} are ignored, as are malformed
        entries without a C{=}.
        """
        self.response_headers = headers
        for cookie in headers.get('set-cookie', ()):
            cook = cookie.split(';')[0]
            if '=' not in cook:
                # Malformed cookie; skip it rather than failing the response.
                continue
            k, v = cook.split('=', 1)
            self.cookies[k.lstrip()] = v.lstrip()

    def gotStatus(self, version, status, message):
        self.version, self.status, self.message = version, status, message

    def page(self, page):
        """Fire the deferred with the page, unless a result was already given."""
        if self.waiting:
            self.waiting = 0
            self.deferred.callback(page)

    def noPage(self, reason):
        """Errback the deferred, unless a result was already given."""
        if self.waiting:
            self.waiting = 0
            self.deferred.errback(reason)

    def clientConnectionFailed(self, _, reason):
        if self.waiting:
            self.waiting = 0
            self.deferred.errback(reason)
315
316
class HTTPDownloader(HTTPClientFactory):
    """
    Download a URL into a file.

    Body data is written to the file as it arrives.  C{deferred} fires with
    C{value} once the file has been closed, or errbacks on failure.
    """

    protocol = HTTPPageDownloader
    value = None

    def __init__(self, url, fileOrName,
                 method='GET', postdata=None, headers=None,
                 agent="Twisted client", supportPartial=0):
        """
        @param fileOrName: A file name, or an already-open file-like object.
            A download can only be resumed when a name is given.
        @param supportPartial: If true and C{fileOrName} names an existing
            non-empty file, only the bytes after its current end are
            requested (via a Range header) and appended to it.
        """
        # Byte offset requested through a Range header; 0 when the whole
        # resource is requested.
        self.requestedPartial = 0
        if isinstance(fileOrName, types.StringTypes):
            self.fileName = fileOrName
            self.file = None
            if supportPartial and os.path.exists(self.fileName):
                fileLength = os.path.getsize(self.fileName)
                if fileLength:
                    self.requestedPartial = fileLength
                    # Work on a copy so the caller's mapping never gains a
                    # Range header (it may be reused for other requests).
                    # The case-insensitive copy also replaces any existing
                    # Range header whatever its spelling.
                    headers = InsensitiveDict(headers)
                    headers["range"] = "bytes=%d-" % fileLength
        else:
            self.file = fileOrName
        HTTPClientFactory.__init__(self, url, method=method, postdata=postdata, headers=headers, agent=agent)
        self.deferred = defer.Deferred()
        self.waiting = 1

    def gotHeaders(self, headers):
        """
        Record the response headers and cookies like L{HTTPClientFactory},
        then check whether a requested partial download is being honoured.
        """
        HTTPClientFactory.gotHeaders(self, headers)
        if self.requestedPartial:
            contentRange = headers.get("content-range", None)
            if not contentRange:
                # server doesn't support partial requests, oh well
                self.requestedPartial = 0
                return
            start, end, realLength = http.parseContentRange(contentRange[0])
            if start != self.requestedPartial:
                # server is acting weirdly
                self.requestedPartial = 0

    def openFile(self, partialContent):
        """
        Open C{self.fileName}: positioned at its end for a partial
        download, truncated otherwise.
        """
        if partialContent:
            file = open(self.fileName, 'rb+')
            file.seek(0, 2)
        else:
            file = open(self.fileName, 'wb')
        return file

    def pageStart(self, partialContent):
        """Called on page download start.

        @param partialContent: tells us if the download is partial download we requested.
        """
        if partialContent and not self.requestedPartial:
            raise ValueError("we shouldn't get partial content response if we didn't want it!")
        if self.waiting:
            self.waiting = 0
            try:
                if not self.file:
                    self.file = self.openFile(partialContent)
            except IOError:
                #raise
                self.deferred.errback(failure.Failure())

    def pagePart(self, data):
        """
        Write a chunk of the body.  On a write error the file is closed and
        the deferred errbacks; later chunks are then ignored.
        """
        if not self.file:
            return
        try:
            self.file.write(data)
        except IOError:
            # Capture the write error before the close below can replace it.
            reason = failure.Failure()
            brokenFile, self.file = self.file, None
            try:
                brokenFile.close()
            except IOError:
                # Best effort only: the write error is what gets reported.
                pass
            self.deferred.errback(reason)

    def pageEnd(self):
        """
        Close the file and fire the deferred with C{self.value}.
        """
        if not self.file:
            return
        try:
            self.file.close()
        except IOError:
            self.deferred.errback(failure.Failure())
            return
        self.deferred.callback(self.value)
398
399
def _parse(url, defaultPort=None):
    """
    Split the given URL into the scheme, host, port, and path.

    @type url: C{str}
    @param url: An URL to parse.

    @type defaultPort: C{int} or C{None}
    @param defaultPort: An alternate value to use as the port if the URL does
    not include one.  When C{None}, 443 is used for C{https} and 80 for
    anything else.

    @return: A four-tuple of the scheme, host, port, and path of the URL.  All
    of these are C{str} instances except for port, which is an C{int}.
    The path is C{"/"} when the URL has none.
    """
    url = url.strip()
    parsed = http.urlparse(url)
    scheme = parsed[0]
    path = urlunparse(('','')+parsed[2:])
    if defaultPort is None:
        if scheme == 'https':
            defaultPort = 443
        else:
            defaultPort = 80
    host, port = parsed[1], defaultPort
    if ':' in host:
        host, port = host.split(':')
        if port:
            port = int(port)
        else:
            # An empty port ("host:") means the default port (RFC 3986,
            # section 3.2.3); int('') would raise ValueError.
            port = defaultPort
    if path == "":
        path = "/"
    return scheme, host, port, path
430
431
def _makeGetterFactory(url, factoryFactory, contextFactory=None,
                       *args, **kwargs):
    """
    Build a page-getting factory for C{url} and start connecting it.

    @param url: The URL to fetch; its scheme selects TCP or SSL, and its
        host and port select where to connect.

    @param factoryFactory: Callable invoked as
        C{factoryFactory(url, *args, **kwargs)} to create the factory.

    @param contextFactory: SSL context factory for C{https} URLs.  When
        C{None}, a default client context factory is created.

    @return: The factory returned by C{factoryFactory}.
    """
    parts = _parse(url)
    scheme, host, port = parts[0], parts[1], parts[2]
    factory = factoryFactory(url, *args, **kwargs)
    if scheme != 'https':
        reactor.connectTCP(host, port, factory)
        return factory
    from twisted.internet import ssl
    if contextFactory is None:
        contextFactory = ssl.ClientContextFactory()
    reactor.connectSSL(host, port, factory, contextFactory)
    return factory
458
459
def getPage(url, contextFactory=None, *args, **kwargs):
    """
    Download a web page as a string.

    Download a page. Return a deferred, which will callback with a
    page (as a string) or errback with a description of the error.

    @param contextFactory: SSL context factory used for C{https} URLs; a
        default one is created when C{None}.

    Any further positional and keyword arguments are passed to
    HTTPClientFactory after C{url}; see HTTPClientFactory for what they
    may be.
    """
    # contextFactory must be passed positionally.  Given as a keyword before
    # *args, Python binds the first element of args to _makeGetterFactory's
    # contextFactory parameter as well, raising "got multiple values for
    # keyword argument 'contextFactory'" for any extra positional argument.
    return _makeGetterFactory(
        url,
        HTTPClientFactory,
        contextFactory,
        *args, **kwargs).deferred
474
475
def downloadPage(url, file, contextFactory=None, *args, **kwargs):
    """
    Download a web page to a file.

    @param file: path to file on filesystem, or file-like object.

    @param contextFactory: SSL context factory used for C{https} URLs; a
        default one is created when C{None}.

    Any further positional and keyword arguments are passed to
    HTTPDownloader after C{url} and C{file}; see HTTPDownloader for what
    they may be.
    """
    factoryFactory = lambda url, *a, **kw: HTTPDownloader(url, file, *a, **kw)
    # contextFactory must be passed positionally.  Given as a keyword before
    # *args, the first extra positional argument would also be bound to it,
    # raising a "multiple values" TypeError.
    return _makeGetterFactory(
        url,
        factoryFactory,
        contextFactory,
        *args, **kwargs).deferred
490
491
# Public names of this module; the underscore-prefixed helpers (_parse,
# _makeGetterFactory) are deliberately left out.
__all__ = [
    'PartialDownloadError',
    'HTTPPageGetter',
    'HTTPPageDownloader',
    'HTTPClientFactory',
    'HTTPDownloader',
    'getPage',
    'downloadPage',
    ]
Note: See TracBrowser for help on using the browser.