Ticket #6927: 6927-4.diff

File 6927-4.diff, 11.2 KB (added by Adi Roiban, 7 years ago)
  • twisted/web/http.py

    diff --git twisted/web/http.py twisted/web/http.py
    index e50be47..d100712 100644
    class Request: 
    817817                        # content-dispostion headers in multipart/form-data
    818818                        # parts, so we catch the exception and tell the client
    819819                        # it was a bad request.
    820                         self.channel.transport.write(
    821                                 b"HTTP/1.1 400 Bad Request\r\n\r\n")
    822                         self.channel.transport.loseConnection()
     820                        _respondToBadRequestAndDisconnect(
     821                            self.channel.transport)
    823822                        return
    824823                    raise
    825824            self.content.seek(0, 0)
    class HTTPChannel(basic.LineReceiver, policies.TimeoutMixin): 
    15891588    """
    15901589    A receiver for HTTP requests.
    15911590
    1592     @ivar _transferDecoder: C{None} or an instance of
    1593         L{_ChunkedTransferDecoder} if the request body uses the I{chunked}
    1594         Transfer-Encoding.
     1591    @ivar MAX_LENGTH: Maximum length for initial request line and each line
     1592        from the header.
     1593
     1594    @ivar _transferDecoder: C{None} or a decoder instance if the request body
     1595        uses the I{chunked} Transfer-Encoding.
     1596    @type _transferDecoder: L{_ChunkedTransferDecoder}
     1597
     1598    @ivar maxHeaders: Maximum number of headers allowed per request.
     1599    @type maxHeaders: C{int}
     1600
     1601    @ivar totalHeadersSize: Maximum bytes for request line plus all headers
     1602        from the request.
     1603    @type totalHeadersSize: C{int}
     1604
     1605    @ivar _receivedHeaderSize: Bytes received so far for the header.
     1606    @type _receivedHeaderSize: C{int}
    15951607    """
    15961608
    1597     maxHeaders = 500 # max number of headers allowed per request
     1609    maxHeaders = 500
     1610    totalHeadersSize = 16384
    15981611
    15991612    length = 0
    16001613    persistent = 1
    class HTTPChannel(basic.LineReceiver, policies.TimeoutMixin): 
    16071620
    16081621    _savedTimeOut = None
    16091622    _receivedHeaderCount = 0
     1623    _receivedHeaderSize = 0
    16101624
    16111625    def __init__(self):
    16121626        # the request queue
    class HTTPChannel(basic.LineReceiver, policies.TimeoutMixin): 
    16171631    def connectionMade(self):
    16181632        self.setTimeout(self.timeOut)
    16191633
     1634
    16201635    def lineReceived(self, line):
     1636        """
     1637        Called for each line from request until the end of headers when
     1638        it enters binary mode.
     1639        """
    16211640        self.resetTimeout()
    16221641
     1642        self._receivedHeaderSize += len(line)
     1643        if (self._receivedHeaderSize > self.totalHeadersSize):
     1644            _respondToBadRequestAndDisconnect(self.transport)
     1645            return
     1646
    16231647        if self.__first_line:
    16241648            # if this connection is not persistent, drop any data which
    16251649            # the client (illegally) sent after the last request.
    class HTTPChannel(basic.LineReceiver, policies.TimeoutMixin): 
    16401664            self.__first_line = 0
    16411665            parts = line.split()
    16421666            if len(parts) != 3:
    1643                 self.transport.write(b"HTTP/1.1 400 Bad Request\r\n\r\n")
    1644                 self.transport.loseConnection()
     1667                _respondToBadRequestAndDisconnect(self.transport)
    16451668                return
    16461669            command, request, version = parts
    16471670            self._command = command
    16481671            self._path = request
    16491672            self._version = version
    16501673        elif line == b'':
     1674            # End of headers.
    16511675            if self.__header:
    16521676                self.headerReceived(self.__header)
    16531677            self.__header = ''
    class HTTPChannel(basic.LineReceiver, policies.TimeoutMixin): 
    16571681            else:
    16581682                self.setRawMode()
    16591683        elif line[0] in b' \t':
     1684            # Continuation of a multi line header.
    16601685            self.__header = self.__header + '\n' + line
     1686        # Regular header line.
     1687        # Processing of header line is delayed to allow accumulating multi
     1688        # line headers.
    16611689        else:
    16621690            if self.__header:
    16631691                self.headerReceived(self.__header)
    class HTTPChannel(basic.LineReceiver, policies.TimeoutMixin): 
    16851713            try:
    16861714                self.length = int(data)
    16871715            except ValueError:
    1688                 self.transport.write(b"HTTP/1.1 400 Bad Request\r\n\r\n")
     1716                respondToBadRequestAndDisconnect(self._transport)
    16891717                self.length = None
    1690                 self.transport.loseConnection()
    16911718                return
    16921719            self._transferDecoder = _IdentityTransferDecoder(
    16931720                self.length, self.requests[-1].handleContentChunk, self._finishRequestBody)
    class HTTPChannel(basic.LineReceiver, policies.TimeoutMixin): 
    17061733
    17071734        self._receivedHeaderCount += 1
    17081735        if self._receivedHeaderCount > self.maxHeaders:
    1709             self.transport.write(b"HTTP/1.1 400 Bad Request\r\n\r\n")
    1710             self.transport.loseConnection()
     1736            _respondToBadRequestAndDisconnect(self.transport)
     1737            return
    17111738
    17121739
    17131740    def allContentReceived(self):
    class HTTPChannel(basic.LineReceiver, policies.TimeoutMixin): 
    17181745        # reset ALL state variables, so we don't interfere with next request
    17191746        self.length = 0
    17201747        self._receivedHeaderCount = 0
     1748        self._receivedHeaderSize = 0
    17211749        self.__first_line = 1
    17221750        self._transferDecoder = None
    17231751        del self._command, self._path, self._version
    class HTTPChannel(basic.LineReceiver, policies.TimeoutMixin): 
    17361764        try:
    17371765            self._transferDecoder.dataReceived(data)
    17381766        except _MalformedChunkedDataError:
    1739             self.transport.write(b"HTTP/1.1 400 Bad Request\r\n\r\n")
    1740             self.transport.loseConnection()
     1767            _respondToBadRequestAndDisconnect(self.transport)
    17411768
    17421769
    17431770    def allHeadersReceived(self):
    class HTTPChannel(basic.LineReceiver, policies.TimeoutMixin): 
    18251852
    18261853
    18271854
     1855def _respondToBadRequestAndDisconnect(transport):
     1856    """
     1857    This is a quick and dirty way of responding to bad requests.
     1858
     1859    As described by HTTP standard we should be patient and accept the
     1860    whole request from the client, before sending a polite bad request
     1861    response, even in the case when clients send tons of data.
     1862    """
     1863    transport.write(b"HTTP/1.1 400 Bad Request\r\n\r\n")
     1864    transport.loseConnection()
     1865
     1866
     1867
    18281868def _escape(s):
    18291869    """
    18301870    Return a string like python repr, but always escaped as if surrounding
  • twisted/web/test/test_http.py

    diff --git twisted/web/test/test_http.py twisted/web/test/test_http.py
    index 5a0a853..12a94f2 100644
    class ParsingTestCase(unittest.TestCase): 
    660660        self.didRequest = False
    661661
    662662
    663     def runRequest(self, httpRequest, requestClass, success=1):
     663    def runRequest(self, httpRequest, requestFactory=None, success=True,
     664                   channel=None):
     665        """
     666        Execute a web request based on plain text content.
     667
     668        @param httpRequest: Content for the request which is processed.
     669        @type httpRequest: C{bytes}
     670
     671        @param requestFactory: 2-argument callable returning a Request.
     672        @type requestFactory: C{callable}
     673
     674        @param success: Value to compare against I{self.didRequest}.
     675        @type success: C{bool}
     676
     677        @param channel: Channel instance over which the request is processed.
     678        @type channel: L{HTTPChannel}
     679
     680        @return: Returns the channel used for processing the request.
     681        @rtype: L{HTTPChannel}
     682        """
     683        if not channel:
     684            channel = http.HTTPChannel()
     685
     686        if requestFactory:
     687            channel.requestFactory = requestFactory
     688
    664689        httpRequest = httpRequest.replace(b"\n", b"\r\n")
    665         b = StringTransport()
    666         a = http.HTTPChannel()
    667         a.requestFactory = requestClass
    668         a.makeConnection(b)
     690        transport = StringTransport()
     691
     692        channel.makeConnection(transport)
    669693        # one byte at a time, to stress it.
    670694        for byte in iterbytes(httpRequest):
    671             if a.transport.disconnecting:
     695            if channel.transport.disconnecting:
    672696                break
    673             a.dataReceived(byte)
    674         a.connectionLost(IOError("all done"))
     697            channel.dataReceived(byte)
     698        channel.connectionLost(IOError("all done"))
     699
    675700        if success:
    676701            self.assertTrue(self.didRequest)
    677702        else:
    678703            self.assertFalse(self.didRequest)
    679         return a
     704        return channel
    680705
    681706
    682707    def test_basicAuth(self):
    class ParsingTestCase(unittest.TestCase): 
    803828            b'\r\n')
    804829
    805830
     831    def test_headersTooBigInitialCommand(self):
     832        """
     833        Enforces a limit of C{HTTPChannel.totalHeadersSize}
     834        on the size of headers received per request starting from initial
     835        command line.
     836        """
     837        channel = http.HTTPChannel()
     838        channel.totalHeadersSize = 10
     839        httpRequest = b'GET /path/longer/than/10 HTTP/1.1\n'
     840
     841        channel = self.runRequest(
     842            httpRequest=httpRequest, channel=channel, success=False)
     843
     844        self.assertEqual(
     845            channel.transport.value(),
     846            b"HTTP/1.1 400 Bad Request\r\n\r\n")
     847
     848
     849    def test_headersTooBigOtherHeaders(self):
     850        """
     851        Enforces a limit of C{HTTPChannel.totalHeadersSize}
     852        on the size of headers received per request counting first line
     853        and total headers.
     854        """
     855        channel = http.HTTPChannel()
     856        channel.totalHeadersSize = 40
     857        httpRequest = (
     858            b'GET /less/than/40 HTTP/1.1\n'
     859            b'Some-Header: less-than-40\n'
     860            )
     861
     862        channel = self.runRequest(
     863            httpRequest=httpRequest, channel=channel, success=False)
     864
     865        self.assertEqual(
     866            channel.transport.value(),
     867            b"HTTP/1.1 400 Bad Request\r\n\r\n")
     868
     869
     870    def test_headersTooBigPerRequest(self):
     871        """
     872        Enforces total size of headers per individual request and counter
     873        is reset at the end of each request.
     874        """
     875        class SimpleRequest(http.Request):
     876            def process(self):
     877                self.finish()
     878        channel = http.HTTPChannel()
     879        channel.totalHeadersSize = 60
     880        channel.requestFactory = SimpleRequest
     881        httpRequest = (
     882            b'GET / HTTP/1.1\n'
     883            b'Some-Header: total-less-than-60\n'
     884            b'\n'
     885            b'GET / HTTP/1.1\n'
     886            b'Some-Header: less-than-60\n'
     887            b'\n'
     888            )
     889
     890        channel = self.runRequest(
     891            httpRequest=httpRequest, channel=channel, success=False)
     892
     893        self.assertEqual(
     894            channel.transport.value(),
     895            b'HTTP/1.1 200 OK\r\n'
     896            b'Transfer-Encoding: chunked\r\n'
     897            b'\r\n'
     898            b'0\r\n'
     899            b'\r\n'
     900            b'HTTP/1.1 200 OK\r\n'
     901            b'Transfer-Encoding: chunked\r\n'
     902            b'\r\n'
     903            b'0\r\n'
     904            b'\r\n'
     905            )
     906
     907
    806908    def testCookies(self):
    807909        """
    808910        Test cookies parsing and reading.
  • new file twisted/web/topfiles/6927.misc

    diff --git twisted/web/topfiles/6927.misc twisted/web/topfiles/6927.misc
    new file mode 100644
    index 0000000..f72a31f
    - +  
     1twisted.web.http.HTTPChannel now limit the total headers size, including first command line, to 16KB.
     2 No newline at end of file