Ticket #3461: issue.3461.1.patch

File issue.3461.1.patch, 8.2 KB (added by steiza, 5 years ago)

Updated fix

  • twisted/web/server.py

    diff --git twisted/web/server.py twisted/web/server.py
    index 46c461a..e9b3c8b 100644
    def _addressToTuple(addr): 
    5353        return tuple(addr)
    5454
    5555class Request(pb.Copyable, http.Request, components.Componentized):
     56    """
     57    An HTTP request.
     58
     59    @ivar session: This stores a session available to HTTP and HTTPS requests.
     60    """
    5661    implements(iweb.IRequest)
    5762
    5863    site = None
    class Request(pb.Copyable, http.Request, components.Componentized): 
    264269    ### these calls remain local
    265270
    266271    session = None
     272    _secureSession = None
     273
     274    def getSession(self, sessionInterface=None, forceNotSecure=False):
     275        """
     276        Check if there is a session cookie, and if not, create it.
     277
     278        By default, the cookie with be secure for HTTPS requests and not secure
     279        for HTTP requests. If for some reason you need access to the insecure
     280        cookie from a secure session you can set L{forceNotSecure} = True.
     281        """
     282        # Make sure we aren't creating a secure session on a non-secure page
     283        cookieString = ''
     284        session = None
     285
     286        secure = self.isSecure()
     287
     288        if secure and forceNotSecure:
     289            secure = False
     290
     291        if not secure:
     292            cookieString = 'TWISTED_SESSION'
     293            session = self.session
     294
     295        else:
     296            cookieString = 'TWISTED_SECURE_SESSION'
     297            session = self._secureSession
    267298
    268     def getSession(self, sessionInterface = None):
    269299        # Session management
    270         if not self.session:
    271             cookiename = string.join(['TWISTED_SESSION'] + self.sitepath, "_")
     300        if not session:
     301            cookiename = string.join([cookieString] + self.sitepath, "_")
    272302            sessionCookie = self.getCookie(cookiename)
    273303            if sessionCookie:
    274304                try:
    275                     self.session = self.site.getSession(sessionCookie)
     305                    session = self.site.getSession(sessionCookie)
    276306                except KeyError:
    277307                    pass
    278308            # if it still hasn't been set, fix it up.
    279             if not self.session:
    280                 self.session = self.site.makeSession()
    281                 self.addCookie(cookiename, self.session.uid, path='/')
    282         self.session.touch()
     309            if not session:
     310                session = self.site.makeSession()
     311                self.addCookie(cookiename, session.uid, path='/',
     312                               secure=secure)
     313
     314        session.touch()
     315
     316        # Save the session to the proper place
     317        if not secure:
     318            self.session = session
     319        else:
     320            self._secureSession = session
     321
    283322        if sessionInterface:
    284             return self.session.getComponent(sessionInterface)
    285         return self.session
     323            return session.getComponent(sessionInterface)
     324
     325        return session
    286326
    287327    def _prePathURL(self, prepath):
    288328        port = self.getHost().port
  • twisted/web/test/test_web.py

    diff --git twisted/web/test/test_web.py twisted/web/test/test_web.py
    index 6306a56..a9ab6be 100644
    class DummyRequest: 
    5252        while self.go:
    5353            prod.resumeProducing()
    5454
     55
    5556    def unregisterProducer(self):
    5657        self.go = 0
    5758
    class DummyRequest: 
    9192        """
    9293        self.outgoingHeaders[name.lower()] = value
    9394
    94     def getSession(self):
    95         if self.session:
    96             return self.session
    97         assert not self.written, "Session cannot be requested after data has been written."
    98         self.session = self.protoSession
    99         return self.session
    100 
    10195
    10296    def render(self, resource):
    10397        """
    class DummyRequest: 
    122116    def write(self, data):
    123117        self.written.append(data)
    124118
     119
    125120    def notifyFinish(self):
    126121        """
    127122        Return a L{Deferred} which is called back with C{None} when the request
    class RequestTests(unittest.TestCase): 
    568563        self.assertTrue(
    569564            verifyObject(iweb.IRequest, server.Request(DummyChannel(), True)))
    570565
    571 
    572     def testChildLink(self):
     566    def test_childLink(self):
    573567        request = server.Request(DummyChannel(), 1)
    574568        request.gotLength(0)
    575569        request.requestReceived('GET', '/foo/bar', 'HTTP/1.0')
    class RequestTests(unittest.TestCase): 
    579573        request.requestReceived('GET', '/foo/bar/', 'HTTP/1.0')
    580574        self.assertEqual(request.childLink('baz'), 'baz')
    581575
    582     def testPrePathURLSimple(self):
     576    def test_prePathURLSimple(self):
    583577        request = server.Request(DummyChannel(), 1)
    584578        request.gotLength(0)
    585579        request.requestReceived('GET', '/foo/bar', 'HTTP/1.0')
    586580        request.setHost('example.com', 80)
    587581        self.assertEqual(request.prePathURL(), 'http://example.com/foo/bar')
    588582
    589     def testPrePathURLNonDefault(self):
     583    def test_prePathURLNonDefault(self):
    590584        d = DummyChannel()
    591585        d.transport.port = 81
    592586        request = server.Request(d, 1)
    class RequestTests(unittest.TestCase): 
    595589        request.requestReceived('GET', '/foo/bar', 'HTTP/1.0')
    596590        self.assertEqual(request.prePathURL(), 'http://example.com:81/foo/bar')
    597591
    598     def testPrePathURLSSLPort(self):
     592    def test_prePathURLSSLPort(self):
    599593        d = DummyChannel()
    600594        d.transport.port = 443
    601595        request = server.Request(d, 1)
    class RequestTests(unittest.TestCase): 
    604598        request.requestReceived('GET', '/foo/bar', 'HTTP/1.0')
    605599        self.assertEqual(request.prePathURL(), 'http://example.com:443/foo/bar')
    606600
    607     def testPrePathURLSSLPortAndSSL(self):
     601    def test_prePathURLSSLPortAndSSL(self):
    608602        d = DummyChannel()
    609603        d.transport = DummyChannel.SSL()
    610604        d.transport.port = 443
    class RequestTests(unittest.TestCase): 
    614608        request.requestReceived('GET', '/foo/bar', 'HTTP/1.0')
    615609        self.assertEqual(request.prePathURL(), 'https://example.com/foo/bar')
    616610
    617     def testPrePathURLHTTPPortAndSSL(self):
     611    def test_prePathURLHTTPPortAndSSL(self):
    618612        d = DummyChannel()
    619613        d.transport = DummyChannel.SSL()
    620614        d.transport.port = 80
    class RequestTests(unittest.TestCase): 
    624618        request.requestReceived('GET', '/foo/bar', 'HTTP/1.0')
    625619        self.assertEqual(request.prePathURL(), 'https://example.com:80/foo/bar')
    626620
    627     def testPrePathURLSSLNonDefault(self):
     621    def test_prePathURLSSLNonDefault(self):
    628622        d = DummyChannel()
    629623        d.transport = DummyChannel.SSL()
    630624        d.transport.port = 81
    class RequestTests(unittest.TestCase): 
    634628        request.requestReceived('GET', '/foo/bar', 'HTTP/1.0')
    635629        self.assertEqual(request.prePathURL(), 'https://example.com:81/foo/bar')
    636630
    637     def testPrePathURLSetSSLHost(self):
     631    def test_prePathURLSetSSLHost(self):
    638632        d = DummyChannel()
    639633        d.transport.port = 81
    640634        request = server.Request(d, 1)
    class RequestTests(unittest.TestCase): 
    643637        request.requestReceived('GET', '/foo/bar', 'HTTP/1.0')
    644638        self.assertEqual(request.prePathURL(), 'https://foo.com:81/foo/bar')
    645639
    646 
    647640    def test_prePathURLQuoting(self):
    648641        """
    649642        L{Request.prePathURL} quotes special characters in the URL segments to
    class RequestTests(unittest.TestCase): 
    656649        request.requestReceived('GET', '/foo%2Fbar', 'HTTP/1.0')
    657650        self.assertEqual(request.prePathURL(), 'http://example.com/foo%2Fbar')
    658651
     652    def test_sessionDifferentFromSecureSession(self):
     653        """
     654        L{Request.session} and L{Request.secure_session} should be two separate
     655        sessions with unique ids.
     656        """
     657        d = DummyChannel()
     658        d.transport = DummyChannel.SSL()
     659        request = server.Request(d, 1)
     660        request.site = server.Site('/')
     661        request.sitepath = []
     662        session = request.getSession(forceNotSecure=True)
     663        secure_session = request.getSession()
     664
     665        # Check that the sessions are not None
     666        self.assertTrue(session != None)
     667        self.assertTrue(secure_session != None)
     668
     669        # Check that the sessions are different
     670        self.assertNotEqual(session.uid, secure_session.uid)
    659671
     672        session.expire()
     673        secure_session.expire()
    660674
    661675class RootResource(resource.Resource):
    662676    isLeaf=0