Ticket #3461: issue.3461.1.patch

File issue.3461.1.patch, 8.2 KB (added by steiza, 3 years ago)

Updated fix

  • twisted/web/server.py

    diff --git twisted/web/server.py twisted/web/server.py
    index 46c461a..e9b3c8b 100644
    def _addressToTuple(addr): 
    5353        return tuple(addr) 
    5454 
    5555class Request(pb.Copyable, http.Request, components.Componentized): 
     56    """ 
     57    An HTTP request. 
     58 
     59    @ivar session: This stores a session available to HTTP and HTTPS requests. 
     60    """ 
    5661    implements(iweb.IRequest) 
    5762 
    5863    site = None 
    class Request(pb.Copyable, http.Request, components.Componentized): 
    264269    ### these calls remain local 
    265270 
    266271    session = None 
     272    _secureSession = None 
     273 
     274    def getSession(self, sessionInterface=None, forceNotSecure=False): 
     275        """ 
     276        Check if there is a session cookie, and if not, create it. 
     277 
     278        By default, the cookie with be secure for HTTPS requests and not secure 
     279        for HTTP requests. If for some reason you need access to the insecure 
     280        cookie from a secure session you can set L{forceNotSecure} = True. 
     281        """ 
     282        # Make sure we aren't creating a secure session on a non-secure page 
     283        cookieString = '' 
     284        session = None 
     285 
     286        secure = self.isSecure() 
     287 
     288        if secure and forceNotSecure: 
     289            secure = False 
     290 
     291        if not secure: 
     292            cookieString = 'TWISTED_SESSION' 
     293            session = self.session 
     294 
     295        else: 
     296            cookieString = 'TWISTED_SECURE_SESSION' 
     297            session = self._secureSession 
    267298 
    268     def getSession(self, sessionInterface = None): 
    269299        # Session management 
    270         if not self.session: 
    271             cookiename = string.join(['TWISTED_SESSION'] + self.sitepath, "_") 
     300        if not session: 
     301            cookiename = string.join([cookieString] + self.sitepath, "_") 
    272302            sessionCookie = self.getCookie(cookiename) 
    273303            if sessionCookie: 
    274304                try: 
    275                     self.session = self.site.getSession(sessionCookie) 
     305                    session = self.site.getSession(sessionCookie) 
    276306                except KeyError: 
    277307                    pass 
    278308            # if it still hasn't been set, fix it up. 
    279             if not self.session: 
    280                 self.session = self.site.makeSession() 
    281                 self.addCookie(cookiename, self.session.uid, path='/') 
    282         self.session.touch() 
     309            if not session: 
     310                session = self.site.makeSession() 
     311                self.addCookie(cookiename, session.uid, path='/', 
     312                               secure=secure) 
     313 
     314        session.touch() 
     315 
     316        # Save the session to the proper place 
     317        if not secure: 
     318            self.session = session 
     319        else: 
     320            self._secureSession = session 
     321 
    283322        if sessionInterface: 
    284             return self.session.getComponent(sessionInterface) 
    285         return self.session 
     323            return session.getComponent(sessionInterface) 
     324 
     325        return session 
    286326 
    287327    def _prePathURL(self, prepath): 
    288328        port = self.getHost().port 
  • twisted/web/test/test_web.py

    diff --git twisted/web/test/test_web.py twisted/web/test/test_web.py
    index 6306a56..a9ab6be 100644
    class DummyRequest: 
    5252        while self.go: 
    5353            prod.resumeProducing() 
    5454 
     55 
    5556    def unregisterProducer(self): 
    5657        self.go = 0 
    5758 
    class DummyRequest: 
    9192        """ 
    9293        self.outgoingHeaders[name.lower()] = value 
    9394 
    94     def getSession(self): 
    95         if self.session: 
    96             return self.session 
    97         assert not self.written, "Session cannot be requested after data has been written." 
    98         self.session = self.protoSession 
    99         return self.session 
    100  
    10195 
    10296    def render(self, resource): 
    10397        """ 
    class DummyRequest: 
    122116    def write(self, data): 
    123117        self.written.append(data) 
    124118 
     119 
    125120    def notifyFinish(self): 
    126121        """ 
    127122        Return a L{Deferred} which is called back with C{None} when the request 
    class RequestTests(unittest.TestCase): 
    568563        self.assertTrue( 
    569564            verifyObject(iweb.IRequest, server.Request(DummyChannel(), True))) 
    570565 
    571  
    572     def testChildLink(self): 
     566    def test_childLink(self): 
    573567        request = server.Request(DummyChannel(), 1) 
    574568        request.gotLength(0) 
    575569        request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') 
    class RequestTests(unittest.TestCase): 
    579573        request.requestReceived('GET', '/foo/bar/', 'HTTP/1.0') 
    580574        self.assertEqual(request.childLink('baz'), 'baz') 
    581575 
    582     def testPrePathURLSimple(self): 
     576    def test_prePathURLSimple(self): 
    583577        request = server.Request(DummyChannel(), 1) 
    584578        request.gotLength(0) 
    585579        request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') 
    586580        request.setHost('example.com', 80) 
    587581        self.assertEqual(request.prePathURL(), 'http://example.com/foo/bar') 
    588582 
    589     def testPrePathURLNonDefault(self): 
     583    def test_prePathURLNonDefault(self): 
    590584        d = DummyChannel() 
    591585        d.transport.port = 81 
    592586        request = server.Request(d, 1) 
    class RequestTests(unittest.TestCase): 
    595589        request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') 
    596590        self.assertEqual(request.prePathURL(), 'http://example.com:81/foo/bar') 
    597591 
    598     def testPrePathURLSSLPort(self): 
     592    def test_prePathURLSSLPort(self): 
    599593        d = DummyChannel() 
    600594        d.transport.port = 443 
    601595        request = server.Request(d, 1) 
    class RequestTests(unittest.TestCase): 
    604598        request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') 
    605599        self.assertEqual(request.prePathURL(), 'http://example.com:443/foo/bar') 
    606600 
    607     def testPrePathURLSSLPortAndSSL(self): 
     601    def test_prePathURLSSLPortAndSSL(self): 
    608602        d = DummyChannel() 
    609603        d.transport = DummyChannel.SSL() 
    610604        d.transport.port = 443 
    class RequestTests(unittest.TestCase): 
    614608        request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') 
    615609        self.assertEqual(request.prePathURL(), 'https://example.com/foo/bar') 
    616610 
    617     def testPrePathURLHTTPPortAndSSL(self): 
     611    def test_prePathURLHTTPPortAndSSL(self): 
    618612        d = DummyChannel() 
    619613        d.transport = DummyChannel.SSL() 
    620614        d.transport.port = 80 
    class RequestTests(unittest.TestCase): 
    624618        request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') 
    625619        self.assertEqual(request.prePathURL(), 'https://example.com:80/foo/bar') 
    626620 
    627     def testPrePathURLSSLNonDefault(self): 
     621    def test_prePathURLSSLNonDefault(self): 
    628622        d = DummyChannel() 
    629623        d.transport = DummyChannel.SSL() 
    630624        d.transport.port = 81 
    class RequestTests(unittest.TestCase): 
    634628        request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') 
    635629        self.assertEqual(request.prePathURL(), 'https://example.com:81/foo/bar') 
    636630 
    637     def testPrePathURLSetSSLHost(self): 
     631    def test_prePathURLSetSSLHost(self): 
    638632        d = DummyChannel() 
    639633        d.transport.port = 81 
    640634        request = server.Request(d, 1) 
    class RequestTests(unittest.TestCase): 
    643637        request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') 
    644638        self.assertEqual(request.prePathURL(), 'https://foo.com:81/foo/bar') 
    645639 
    646  
    647640    def test_prePathURLQuoting(self): 
    648641        """ 
    649642        L{Request.prePathURL} quotes special characters in the URL segments to 
    class RequestTests(unittest.TestCase): 
    656649        request.requestReceived('GET', '/foo%2Fbar', 'HTTP/1.0') 
    657650        self.assertEqual(request.prePathURL(), 'http://example.com/foo%2Fbar') 
    658651 
     652    def test_sessionDifferentFromSecureSession(self): 
     653        """ 
     654        L{Request.session} and L{Request.secure_session} should be two separate 
     655        sessions with unique ids. 
     656        """ 
     657        d = DummyChannel() 
     658        d.transport = DummyChannel.SSL() 
     659        request = server.Request(d, 1) 
     660        request.site = server.Site('/') 
     661        request.sitepath = [] 
     662        session = request.getSession(forceNotSecure=True) 
     663        secure_session = request.getSession() 
     664 
     665        # Check that the sessions are not None 
     666        self.assertTrue(session != None) 
     667        self.assertTrue(secure_session != None) 
     668 
     669        # Check that the sessions are different 
     670        self.assertNotEqual(session.uid, secure_session.uid) 
    659671 
     672        session.expire() 
     673        secure_session.expire() 
    660674 
    661675class RootResource(resource.Resource): 
    662676    isLeaf=0