diff --git twisted/web/server.py twisted/web/server.py
index 46c461a..e9b3c8b 100644
|
|
|
|
| 53 | 53 | return tuple(addr) |
| 54 | 54 | |
| 55 | 55 | class Request(pb.Copyable, http.Request, components.Componentized): |
| | 56 | """ |
| | 57 | An HTTP request. |
| | 58 | |
| | 59 | @ivar session: This stores a session available to HTTP and HTTPS requests. |
| | 60 | """ |
| 56 | 61 | implements(iweb.IRequest) |
| 57 | 62 | |
| 58 | 63 | site = None |
| … |
… |
|
| 264 | 269 | ### these calls remain local |
| 265 | 270 | |
| 266 | 271 | session = None |
| | 272 | _secureSession = None |
| | 273 | |
| | 274 | def getSession(self, sessionInterface=None, forceNotSecure=False): |
| | 275 | """ |
| | 276 | Check if there is a session cookie, and if not, create it. |
| | 277 | |
| | 278 | By default, the cookie with be secure for HTTPS requests and not secure |
| | 279 | for HTTP requests. If for some reason you need access to the insecure |
| | 280 | cookie from a secure session you can set L{forceNotSecure} = True. |
| | 281 | """ |
| | 282 | # Make sure we aren't creating a secure session on a non-secure page |
| | 283 | cookieString = '' |
| | 284 | session = None |
| | 285 | |
| | 286 | secure = self.isSecure() |
| | 287 | |
| | 288 | if secure and forceNotSecure: |
| | 289 | secure = False |
| | 290 | |
| | 291 | if not secure: |
| | 292 | cookieString = 'TWISTED_SESSION' |
| | 293 | session = self.session |
| | 294 | |
| | 295 | else: |
| | 296 | cookieString = 'TWISTED_SECURE_SESSION' |
| | 297 | session = self._secureSession |
| 267 | 298 | |
| 268 | | def getSession(self, sessionInterface = None): |
| 269 | 299 | # Session management |
| 270 | | if not self.session: |
| 271 | | cookiename = string.join(['TWISTED_SESSION'] + self.sitepath, "_") |
| | 300 | if not session: |
| | 301 | cookiename = string.join([cookieString] + self.sitepath, "_") |
| 272 | 302 | sessionCookie = self.getCookie(cookiename) |
| 273 | 303 | if sessionCookie: |
| 274 | 304 | try: |
| 275 | | self.session = self.site.getSession(sessionCookie) |
| | 305 | session = self.site.getSession(sessionCookie) |
| 276 | 306 | except KeyError: |
| 277 | 307 | pass |
| 278 | 308 | # if it still hasn't been set, fix it up. |
| 279 | | if not self.session: |
| 280 | | self.session = self.site.makeSession() |
| 281 | | self.addCookie(cookiename, self.session.uid, path='/') |
| 282 | | self.session.touch() |
| | 309 | if not session: |
| | 310 | session = self.site.makeSession() |
| | 311 | self.addCookie(cookiename, session.uid, path='/', |
| | 312 | secure=secure) |
| | 313 | |
| | 314 | session.touch() |
| | 315 | |
| | 316 | # Save the session to the proper place |
| | 317 | if not secure: |
| | 318 | self.session = session |
| | 319 | else: |
| | 320 | self._secureSession = session |
| | 321 | |
| 283 | 322 | if sessionInterface: |
| 284 | | return self.session.getComponent(sessionInterface) |
| 285 | | return self.session |
| | 323 | return session.getComponent(sessionInterface) |
| | 324 | |
| | 325 | return session |
| 286 | 326 | |
| 287 | 327 | def _prePathURL(self, prepath): |
| 288 | 328 | port = self.getHost().port |
diff --git twisted/web/test/test_web.py twisted/web/test/test_web.py
index 6306a56..a9ab6be 100644
|
|
|
|
| 52 | 52 | while self.go: |
| 53 | 53 | prod.resumeProducing() |
| 54 | 54 | |
| | 55 | |
| 55 | 56 | def unregisterProducer(self): |
| 56 | 57 | self.go = 0 |
| 57 | 58 | |
| … |
… |
|
| 91 | 92 | """ |
| 92 | 93 | self.outgoingHeaders[name.lower()] = value |
| 93 | 94 | |
| 94 | | def getSession(self): |
| 95 | | if self.session: |
| 96 | | return self.session |
| 97 | | assert not self.written, "Session cannot be requested after data has been written." |
| 98 | | self.session = self.protoSession |
| 99 | | return self.session |
| 100 | | |
| 101 | 95 | |
| 102 | 96 | def render(self, resource): |
| 103 | 97 | """ |
| … |
… |
|
| 122 | 116 | def write(self, data): |
| 123 | 117 | self.written.append(data) |
| 124 | 118 | |
| | 119 | |
| 125 | 120 | def notifyFinish(self): |
| 126 | 121 | """ |
| 127 | 122 | Return a L{Deferred} which is called back with C{None} when the request |
| … |
… |
|
| 568 | 563 | self.assertTrue( |
| 569 | 564 | verifyObject(iweb.IRequest, server.Request(DummyChannel(), True))) |
| 570 | 565 | |
| 571 | | |
| 572 | | def testChildLink(self): |
| | 566 | def test_childLink(self): |
| 573 | 567 | request = server.Request(DummyChannel(), 1) |
| 574 | 568 | request.gotLength(0) |
| 575 | 569 | request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') |
| … |
… |
|
| 579 | 573 | request.requestReceived('GET', '/foo/bar/', 'HTTP/1.0') |
| 580 | 574 | self.assertEqual(request.childLink('baz'), 'baz') |
| 581 | 575 | |
| 582 | | def testPrePathURLSimple(self): |
| | 576 | def test_prePathURLSimple(self): |
| 583 | 577 | request = server.Request(DummyChannel(), 1) |
| 584 | 578 | request.gotLength(0) |
| 585 | 579 | request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') |
| 586 | 580 | request.setHost('example.com', 80) |
| 587 | 581 | self.assertEqual(request.prePathURL(), 'http://example.com/foo/bar') |
| 588 | 582 | |
| 589 | | def testPrePathURLNonDefault(self): |
| | 583 | def test_prePathURLNonDefault(self): |
| 590 | 584 | d = DummyChannel() |
| 591 | 585 | d.transport.port = 81 |
| 592 | 586 | request = server.Request(d, 1) |
| … |
… |
|
| 595 | 589 | request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') |
| 596 | 590 | self.assertEqual(request.prePathURL(), 'http://example.com:81/foo/bar') |
| 597 | 591 | |
| 598 | | def testPrePathURLSSLPort(self): |
| | 592 | def test_prePathURLSSLPort(self): |
| 599 | 593 | d = DummyChannel() |
| 600 | 594 | d.transport.port = 443 |
| 601 | 595 | request = server.Request(d, 1) |
| … |
… |
|
| 604 | 598 | request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') |
| 605 | 599 | self.assertEqual(request.prePathURL(), 'http://example.com:443/foo/bar') |
| 606 | 600 | |
| 607 | | def testPrePathURLSSLPortAndSSL(self): |
| | 601 | def test_prePathURLSSLPortAndSSL(self): |
| 608 | 602 | d = DummyChannel() |
| 609 | 603 | d.transport = DummyChannel.SSL() |
| 610 | 604 | d.transport.port = 443 |
| … |
… |
|
| 614 | 608 | request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') |
| 615 | 609 | self.assertEqual(request.prePathURL(), 'https://example.com/foo/bar') |
| 616 | 610 | |
| 617 | | def testPrePathURLHTTPPortAndSSL(self): |
| | 611 | def test_prePathURLHTTPPortAndSSL(self): |
| 618 | 612 | d = DummyChannel() |
| 619 | 613 | d.transport = DummyChannel.SSL() |
| 620 | 614 | d.transport.port = 80 |
| … |
… |
|
| 624 | 618 | request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') |
| 625 | 619 | self.assertEqual(request.prePathURL(), 'https://example.com:80/foo/bar') |
| 626 | 620 | |
| 627 | | def testPrePathURLSSLNonDefault(self): |
| | 621 | def test_prePathURLSSLNonDefault(self): |
| 628 | 622 | d = DummyChannel() |
| 629 | 623 | d.transport = DummyChannel.SSL() |
| 630 | 624 | d.transport.port = 81 |
| … |
… |
|
| 634 | 628 | request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') |
| 635 | 629 | self.assertEqual(request.prePathURL(), 'https://example.com:81/foo/bar') |
| 636 | 630 | |
| 637 | | def testPrePathURLSetSSLHost(self): |
| | 631 | def test_prePathURLSetSSLHost(self): |
| 638 | 632 | d = DummyChannel() |
| 639 | 633 | d.transport.port = 81 |
| 640 | 634 | request = server.Request(d, 1) |
| … |
… |
|
| 643 | 637 | request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') |
| 644 | 638 | self.assertEqual(request.prePathURL(), 'https://foo.com:81/foo/bar') |
| 645 | 639 | |
| 646 | | |
| 647 | 640 | def test_prePathURLQuoting(self): |
| 648 | 641 | """ |
| 649 | 642 | L{Request.prePathURL} quotes special characters in the URL segments to |
| … |
… |
|
| 656 | 649 | request.requestReceived('GET', '/foo%2Fbar', 'HTTP/1.0') |
| 657 | 650 | self.assertEqual(request.prePathURL(), 'http://example.com/foo%2Fbar') |
| 658 | 651 | |
| | 652 | def test_sessionDifferentFromSecureSession(self): |
| | 653 | """ |
| | 654 | L{Request.session} and L{Request.secure_session} should be two separate |
| | 655 | sessions with unique ids. |
| | 656 | """ |
| | 657 | d = DummyChannel() |
| | 658 | d.transport = DummyChannel.SSL() |
| | 659 | request = server.Request(d, 1) |
| | 660 | request.site = server.Site('/') |
| | 661 | request.sitepath = [] |
| | 662 | session = request.getSession(forceNotSecure=True) |
| | 663 | secure_session = request.getSession() |
| | 664 | |
| | 665 | # Check that the sessions are not None |
| | 666 | self.assertTrue(session != None) |
| | 667 | self.assertTrue(secure_session != None) |
| | 668 | |
| | 669 | # Check that the sessions are different |
| | 670 | self.assertNotEqual(session.uid, secure_session.uid) |
| 659 | 671 | |
| | 672 | session.expire() |
| | 673 | secure_session.expire() |
| 660 | 674 | |
| 661 | 675 | class RootResource(resource.Resource): |
| 662 | 676 | isLeaf=0 |