Ticket #612: httpcache.py

File httpcache.py, 4.5 KB (added by dialtone, 11 years ago)
Line 
1# twisted.web.client.py
2
3class HTTPCacheDownloader(HTTPPageGetter):
4    def connectionMade(self, isCached=False):
5        method = getattr(self.factory, 'method', 'GET')
6        self.sendCommand(method, self.factory.path)
7        self.sendHeader('Host', self.factory.headers.get("host", self.factory.host))
8        self.sendHeader('User-Agent', self.factory.agent)
9       
10        if self.factory.isCached:
11            self.sendHeader('If-None-Match', self.factory.headers.get("etag", '') )
12            self.sendHeader('If-Modified-Since',
13                    self.factory.headers.get("if-modified-since", '') )
14           
15        if self.factory.cookies:
16            l=[]
17            for cookie, cookval in self.factory.cookies.items(): 
18                l.append('%s=%s' % (cookie, cookval))
19            self.sendHeader('Cookie', '; '.join(l))
20        data = getattr(self.factory, 'postdata', None)
21        if data is not None:
22            self.sendHeader("Content-Length", str(len(data)))
23        for (key, value) in self.factory.headers.items():
24            if key.lower() != "content-length":
25                # we calculated it on our own
26                self.sendHeader(key, value)
27        self.endHeaders()
28        self.headers = {}
29       
30        if data is not None:
31            self.transport.write(data) 
32
33    def handleResponse(self, response):
34        re = self.headers.get("etag", None)
35        rl = self.headers.get("last-modified", None)
36        rd = self.headers.get("date", None)
37        if re or rl or rd:
38            cache = {'response':response}
39            if re:
40                cache.update({'etag':re})
41            if rl and rd:
42                cache.update({'if-modified-since':rl})
43            elif rd and not rl:
44                cache.update({'if-modified-since':rd})
45            self.factory.cache[self.factory.url] = cache                 
46        HTTPPageGetter.handleResponse(self, response)
47       
48    def handleStatus_304(self):
49        cache_entry = self.factory.cache.get(self.factory.url, None)
50        if not cache_entry:
51            self.factory.noPage(
52                failure.Failure(
53                    error.Error(
54                        self.status, self.message, "Page missing in cache")))
55            self.transport.loseConnection()
56        self.handleResponse(cache_entry.get('response'))
57
58
59class HTTPClientCacheFactory(HTTPClientFactory):
60
61    protocol = HTTPCacheDownloader
62    cache = {}
63    isCached = False
64
65    def __init__(self, url, method='GET', postdata=None, headers=None,
66                 agent="Twisted PageGetter", timeout=0, cookies=None,
67                 followRedirect=1):
68        headers = {}
69        cached = self.cache.get(url, None)
70        if cached:
71            self.isCached = True
72            etag = cached.get('etag', None)
73            if_modified_since = cached.get('if-modified-since', None)
74            if etag:
75                headers.setdefault('etag', etag)
76            if if_modified_since:
77                headers.setdefault('if-modified-since', if_modified_since)
78        else:
79            self.isCached = False
80
81        HTTPClientFactory.__init__(self, url=url, method=method,
82                postdata=postdata, headers=headers, agent=agent,
83                timeout=timeout, cookies=cookies, followRedirect=followRedirect)
84        self.deferred = defer.Deferred()
85
86def getPageCached(url, contextFactory=None, *args, **kwargs):
87    """download a web page as a string, keep a cache of already downloaded pages
88
89    Download a page. Return a deferred, which will callback with a
90    page (as a string) or errback with a description of the error.
91
92    See HTTPClientCacheFactory to see what extra args can be passed.
93    """       
94    scheme, host, port, path = _parse(url)
95    factory = HTTPClientCacheFactory(url, *args, **kwargs)
96    if scheme == 'https':
97        from twisted.internet import ssl
98        if contextFactory is None:
99            contextFactory = ssl.ClientContextFactory()
100        reactor.connectSSL(host, port, factory, contextFactory)
101    else:
102        reactor.connectTCP(host, port, factory)
103    return factory.deferred
104   
105#
106# Tests for tests_webclient.py
107#
108
109    def testGetPageCached(self):
110        self.assertEquals(unittest.deferredResult(client.getPageCached(self.getURL("file"))),
111                          "0123456789")
112
113
114
115    def testTimeoutCached(self):
116        r = unittest.deferredResult(client.getPageCached(self.getURL("wait"), timeout=1.5))
117        self.assertEquals(r, 'hello!!!')
118        f = unittest.deferredError(client.getPageCached(self.getURL("wait"), timeout=0.5))
119        f.trap(defer.TimeoutError)