root/tags/releases/twisted-8.2.0/twisted/protocols/basic.py

Revision 25267, 15.9 KB (checked in by exarkun, 22 months ago)

Merge amp-key-enforcer-3478

Author: exarkun
Reviewer: glyph
Fixes: #3478

Change BinaryBoxProtocol so that as soon as it receives a key length
prefix larger than 255 it closes the connection and reports a protocol
violation as the reason for any failures. Keys are restricted to 255
or fewer bytes by the protocol; early disconnection for violation of this
simplifies tracking down bugs in protocol implementations.

Line 
1# -*- test-case-name: twisted.test.test_protocols -*-
2# Copyright (c) 2001-2008 Twisted Matrix Laboratories.
3# See LICENSE for details.
4
5
6"""
7Basic protocols, such as line-oriented, netstring, and int prefixed strings.
8
9Maintainer: Itamar Shtull-Trauring
10"""
11
12# System imports
13import re
14import struct
15
16from zope.interface import implements
17
18# Twisted imports
19from twisted.internet import protocol, defer, interfaces, error
20from twisted.python import log
21
# Parser states used by NetstringReceiver: reading the decimal length
# prefix, reading the payload, or expecting the terminating comma.
LENGTH, DATA, COMMA = range(3)
# Matches an optional run of digits followed by an optional colon.  A raw
# string is used so that ``\d`` reaches the regex engine verbatim instead of
# relying on Python passing an unrecognised string escape through unchanged.
NUMBER = re.compile(r'(\d*)(:?)')
# When true, NetstringParseError is raised with the repr of the offending
# data as its argument.
DEBUG = 0
25
class NetstringParseError(ValueError):
    """
    Raised when received data does not follow the Netstring format.
    """
29
30
class NetstringReceiver(protocol.Protocol):
    """
    Break incoming data up into strings using djb's Netstrings protocol.

    Each complete string is passed to L{stringReceived} as its single
    argument.

    Security features:
        1. Messages are limited in size, useful if you don't want someone
           sending you a 500MB netstring (change MAX_LENGTH to the maximum
           length you wish to accept).
        2. The connection is lost if an illegal message is received, and
           any data delivered after that point is ignored.

    @cvar MAX_LENGTH: The largest length prefix which will be accepted.
    @ivar brokenPeer: Set to 1 once a malformed netstring has been received
        and the transport has been told to close.
    """

    MAX_LENGTH = 99999
    brokenPeer = 0
    _readerState = LENGTH
    _readerLength = 0

    def stringReceived(self, line):
        """
        Called with each complete string.  Override this.
        """
        raise NotImplementedError

    def doData(self):
        """
        Move up to C{_readerLength} bytes of pending data into the payload
        buffer.  Once the whole payload has arrived, deliver it to
        L{stringReceived} and start expecting the trailing comma.
        """
        wanted = int(self._readerLength)
        chunk, self.__data = self.__data[:wanted], self.__data[wanted:]
        self._readerLength = self._readerLength - len(chunk)
        self.__buffer = self.__buffer + chunk
        if self._readerLength != 0:
            # The payload is still incomplete; wait for more data.
            return
        self.stringReceived(self.__buffer)
        self._readerState = COMMA

    def doComma(self):
        """
        Consume the comma which terminates a netstring.

        @raise NetstringParseError: if the next byte is not a comma.
        """
        self._readerState = LENGTH
        if self.__data[0] != ',':
            if DEBUG:
                raise NetstringParseError(repr(self.__data))
            else:
                raise NetstringParseError
        self.__data = self.__data[1:]


    def doLength(self):
        """
        Consume the digits (and the colon, if present) at the start of the
        pending data.  The length prefix may be split across several calls,
        so the digits are accumulated into C{_readerLength}.

        @raise NetstringParseError: if the data starts with neither a digit
            nor a colon, or if the length exceeds C{MAX_LENGTH}.
        """
        m = NUMBER.match(self.__data)
        if not m.end():
            if DEBUG:
                raise NetstringParseError(repr(self.__data))
            else:
                raise NetstringParseError
        self.__data = self.__data[m.end():]
        digits = m.group(1)
        if digits:
            try:
                # Shift the digits seen so far left and append the new ones.
                self._readerLength = (
                    self._readerLength * (10 ** len(digits)) + int(digits))
            except OverflowError:
                raise NetstringParseError("netstring too long")
            if self._readerLength > self.MAX_LENGTH:
                raise NetstringParseError("netstring too long")
        if m.group(2):
            # The colon ends the length prefix; the payload follows.
            self.__buffer = ''
            self._readerState = DATA

    def dataReceived(self, data):
        """
        Parse received data, dispatching on the current reader state until
        all of it has been consumed.  On a parse error the transport is told
        to close and C{brokenPeer} is set.
        """
        if self.brokenPeer:
            # A malformed netstring has already been seen and the connection
            # is being closed; the parser state is no longer meaningful, so
            # do not deliver anything further from this peer.
            return
        self.__data = data
        try:
            while self.__data:
                if self._readerState == DATA:
                    self.doData()
                elif self._readerState == COMMA:
                    self.doComma()
                elif self._readerState == LENGTH:
                    self.doLength()
                else:
                    raise RuntimeError("mode is not DATA, COMMA or LENGTH")
        except NetstringParseError:
            self.transport.loseConnection()
            self.brokenPeer = 1

    def sendString(self, data):
        """
        Write C{data} to the transport, framed as a netstring.

        @type data: C{str}
        """
        self.transport.write('%d:%s,' % (len(data), data))
111
112
class SafeNetstringReceiver(NetstringReceiver):
    """
    Deprecated; use L{NetstringReceiver} instead.  This subclass adds
    nothing to it.
    """
116
117
class LineOnlyReceiver(protocol.Protocol):
    """
    A protocol that receives only lines.

    This is purely a speed optimisation over LineReceiver, for the
    cases that raw mode is known to be unnecessary.

    @cvar delimiter: The line-ending delimiter to use. By default this is
                     '\\r\\n'.
    @cvar MAX_LENGTH: The maximum length of a line to allow (If a
                      received line is longer than this, the connection is
                      dropped).  Default is 16384.
    """
    _buffer = ''
    delimiter = '\r\n'
    MAX_LENGTH = 16384

    def dataReceived(self, data):
        """
        Translate bytes into lines, and call lineReceived for each one.

        Any trailing partial line is kept in C{_buffer} until more data
        arrives.
        """
        lines = (self._buffer + data).split(self.delimiter)
        # The last element is whatever follows the final delimiter: an
        # incomplete line (possibly empty) which must wait for more data.
        self._buffer = lines.pop(-1)
        for line in lines:
            if self.transport.disconnecting:
                # this is necessary because the transport may be told to lose
                # the connection by a line within a larger packet, and it is
                # important to disregard all the lines in that packet following
                # the one that told it to close.
                return
            if len(line) > self.MAX_LENGTH:
                return self.lineLengthExceeded(line)
            self.lineReceived(line)
        if len(self._buffer) > self.MAX_LENGTH:
            return self.lineLengthExceeded(self._buffer)

    def lineReceived(self, line):
        """
        Override this for when each line is received.
        """
        raise NotImplementedError

    def sendLine(self, line):
        """
        Send a line, followed by the delimiter, to the other end of the
        connection.
        """
        return self.transport.writeSequence((line, self.delimiter))

    def lineLengthExceeded(self, line):
        """
        Called when the maximum line length has been reached.
        Override if it needs to be dealt with in some special way.

        The default implementation closes the connection by calling
        C{loseConnection} on the transport, as L{LineReceiver} does, rather
        than depending on the transport to interpret a value returned from
        C{dataReceived}.
        """
        return self.transport.loseConnection()
167
168
169class _PauseableMixin:
170    paused = False
171
172    def pauseProducing(self):
173        self.paused = True
174        self.transport.pauseProducing()
175
176    def resumeProducing(self):
177        self.paused = False
178        self.transport.resumeProducing()
179        self.dataReceived('')
180
181    def stopProducing(self):
182        self.paused = True
183        self.transport.stopProducing()
184
185
class LineReceiver(protocol.Protocol, _PauseableMixin):
    """
    A protocol that receives lines and/or raw data, depending on mode.

    While in line mode, every received line results in a call to
    L{lineReceived}.  While in raw data mode, every chunk of raw data
    results in a call to L{rawDataReceived}.  Use L{setLineMode} and
    L{setRawMode} to switch between the two.

    This is useful for line-oriented protocols such as IRC, HTTP, POP, etc.

    @cvar delimiter: The line-ending delimiter to use. By default this is
                     '\\r\\n'.
    @cvar MAX_LENGTH: The maximum length of a line to allow (If a
                      sent line is longer than this, the connection is dropped).
                      Default is 16384.
    """
    line_mode = 1
    __buffer = ''
    delimiter = '\r\n'
    MAX_LENGTH = 16384

    def clearLineBuffer(self):
        """
        Discard any data which has been buffered but not yet delivered.
        """
        self.__buffer = ""

    def dataReceived(self, data):
        """
        Protocol.dataReceived.

        Buffer the received bytes, then deliver complete lines to
        lineReceived while in line mode, or hand the whole buffer to
        rawDataReceived while in raw mode.  Nothing is delivered while
        paused.
        """
        self.__buffer = self.__buffer + data
        while self.line_mode and not self.paused:
            try:
                line, self.__buffer = self.__buffer.split(self.delimiter, 1)
            except ValueError:
                # No complete line is buffered.  Wait for more data unless
                # the partial line is already longer than allowed.
                if len(self.__buffer) <= self.MAX_LENGTH:
                    break
                overlong, self.__buffer = self.__buffer, ''
                return self.lineLengthExceeded(overlong)
            if len(line) > self.MAX_LENGTH:
                # Hand over the long line together with everything buffered
                # after it.
                overlong = line + self.__buffer
                self.__buffer = ''
                return self.lineLengthExceeded(overlong)
            why = self.lineReceived(line)
            if why or self.transport and self.transport.disconnecting:
                return why
        else:
            # Reached only when the loop condition became false: either raw
            # mode is active or the protocol has been paused.
            if self.paused:
                return
            pending, self.__buffer = self.__buffer, ''
            if pending:
                return self.rawDataReceived(pending)

    def setLineMode(self, extra=''):
        """
        Switch this receiver to line mode.

        When calling this from a rawDataReceived callback, any unhandled
        data may be passed as C{extra}; it will be parsed for lines.
        Data received afterwards is delivered to lineReceived rather than
        rawDataReceived.

        Do not pass extra data if calling this function from
        within a lineReceived callback.
        """
        self.line_mode = 1
        if extra:
            return self.dataReceived(extra)

    def setRawMode(self):
        """
        Switch this receiver to raw mode.

        Data received afterwards is delivered to rawDataReceived rather
        than lineReceived.
        """
        self.line_mode = 0

    def rawDataReceived(self, data):
        """
        Override this for when raw data is received.
        """
        raise NotImplementedError

    def lineReceived(self, line):
        """
        Override this for when each line is received.
        """
        raise NotImplementedError

    def sendLine(self, line):
        """
        Send a line, followed by the delimiter, to the other end of the
        connection.
        """
        return self.transport.write(line + self.delimiter)

    def lineLengthExceeded(self, line):
        """
        Called when the maximum line length has been reached.
        Override if it needs to be dealt with in some special way.

        The argument 'line' contains the remainder of the buffer, starting
        with (at least some part) of the line which is too long. This may
        be more than one line, or may be only the initial portion of the
        line.
        """
        return self.transport.loseConnection()
288
289
class StringTooLongError(AssertionError):
    """
    Raised on an attempt to send a string whose length does not fit in the
    length prefix of a length prefixed protocol.
    """
295
296
class IntNStringReceiver(protocol.Protocol, _PauseableMixin):
    """
    Generic base class for protocols which prefix each string with its
    length.

    @ivar recvd: buffer holding received data which has not yet been
        delivered as complete strings.
    @type recvd: C{str}

    @ivar structFormat: format used for struct packing/unpacking. Define it in
        subclass.
    @type structFormat: C{str}

    @ivar prefixLength: length of the prefix, in bytes. Define it in subclass,
        using C{struct.calcsize(structFormat)}
    @type prefixLength: C{int}
    """
    MAX_LENGTH = 99999
    recvd = ""

    def stringReceived(self, msg):
        """
        Called with each complete string.  Override this.
        """
        raise NotImplementedError


    def lengthLimitExceeded(self, length):
        """
        Callback invoked when a length prefix greater than C{MAX_LENGTH} is
        received.  The default implementation disconnects the transport.
        Override this.

        @param length: The length prefix which was received.
        @type length: C{int}
        """
        self.transport.loseConnection()


    def dataReceived(self, recd):
        """
        Buffer the received data and call stringReceived once for every
        complete int prefixed string it contains.
        """
        self.recvd = self.recvd + recd
        while len(self.recvd) >= self.prefixLength and not self.paused:
            headerSize = self.prefixLength
            (length,) = struct.unpack(
                self.structFormat, self.recvd[:headerSize])
            if length > self.MAX_LENGTH:
                self.lengthLimitExceeded(length)
                return
            frameEnd = headerSize + length
            if len(self.recvd) < frameEnd:
                # The payload has not fully arrived yet.
                break
            packet = self.recvd[headerSize:frameEnd]
            self.recvd = self.recvd[frameEnd:]
            self.stringReceived(packet)

    def sendString(self, data):
        """
        Send a length prefixed string to the other end of the connection.

        @type data: C{str}

        @raise StringTooLongError: if the length of C{data} cannot be
            represented in C{prefixLength} bytes.
        """
        limit = 2 ** (8 * self.prefixLength)
        size = len(data)
        if size >= limit:
            raise StringTooLongError(
                "Try to send %s bytes whereas maximum is %s" % (size, limit))
        self.transport.write(struct.pack(self.structFormat, size) + data)
362
363
class Int32StringReceiver(IntNStringReceiver):
    """
    A receiver for int32-prefixed strings.

    Each string is preceded by 4 bytes holding its length as a 32-bit
    integer in network byte order.

    This class publishes the same interface as NetstringReceiver.
    """
    structFormat = "!I"
    prefixLength = struct.calcsize(structFormat)
375
376
class Int16StringReceiver(IntNStringReceiver):
    """
    A receiver for int16-prefixed strings.

    Each string is preceded by 2 bytes holding its length as a 16-bit
    integer in network byte order.

    This class publishes the same interface as NetstringReceiver.
    """
    structFormat = "!H"
    prefixLength = struct.calcsize(structFormat)
388
389
class Int8StringReceiver(IntNStringReceiver):
    """
    A receiver for int8-prefixed strings.

    Each string is preceded by a single byte holding its length as an
    8-bit integer.

    This class publishes the same interface as NetstringReceiver.
    """
    structFormat = "!B"
    prefixLength = struct.calcsize(structFormat)
401
402
class StatefulStringProtocol:
    """
    A stateful string protocol.

    This is a mixin for string protocols (Int32StringReceiver,
    NetstringReceiver) which turns each stringReceived call into a call
    to the method named 'proto_' plus the current state.

    The state 'done' is special; if a proto_* method returns it, the
    connection will be closed immediately.
    """

    state = 'init'

    def stringReceived(self, string):
        """
        Dispatch a received string to the handler for the current state.

        The first handler called is proto_init.  Whatever a handler
        returns becomes the new state, so if self.proto_init returns
        'foo' then self.proto_foo handles the next message.  If no
        handler exists for the current state, that is logged and the
        state is left unchanged.
        """
        try:
            handler = getattr(self, 'proto_' + self.state)
        except AttributeError:
            log.msg('callback', self.state, 'not found')
            return
        self.state = handler(string)
        if self.state == 'done':
            self.transport.loseConnection()
435
class FileSender:
    """
    A producer that sends the contents of a file to a consumer.

    This is a helper for protocols that, at some point, will take a
    file-like object, read its contents, and write them out to the network,
    optionally performing some transformation on the bytes in between.

    @cvar CHUNK_SIZE: The number of bytes requested from the file on each
        read.
    @ivar lastSent: The last byte written to the consumer so far, or an
        empty string if nothing has been written.
    @ivar deferred: The Deferred returned by L{beginFileTransfer}, or
        C{None} when no transfer is pending.
    """
    implements(interfaces.IProducer)

    CHUNK_SIZE = 2 ** 14

    lastSent = ''
    deferred = None

    def beginFileTransfer(self, file, consumer, transform=None):
        """
        Begin transferring a file

        @type file: Any file-like object
        @param file: The file object to read data from

        @type consumer: Any implementor of IConsumer
        @param consumer: The object to write data to

        @param transform: A callable taking one string argument and returning
        the same.  All bytes read from the file are passed through this before
        being written to the consumer.

        @rtype: C{Deferred}
        @return: A deferred whose callback will be invoked when the file has been
        completely written to the consumer.  The last byte written to the consumer
        is passed to the callback.
        """
        self.file = file
        self.consumer = consumer
        self.transform = transform

        self.deferred = deferred = defer.Deferred()
        # Registered as a non-streaming producer: resumeProducing is called
        # each time the consumer wants another chunk.
        self.consumer.registerProducer(self, False)
        return deferred

    def resumeProducing(self):
        """
        Read the next chunk from the file and write it to the consumer.

        When the file is exhausted, unregister from the consumer and fire
        the deferred with the last byte that was written.
        """
        chunk = ''
        if self.file:
            chunk = self.file.read(self.CHUNK_SIZE)
        if not chunk:
            self.file = None
            self.consumer.unregisterProducer()
            if self.deferred:
                self.deferred.callback(self.lastSent)
                self.deferred = None
            return

        if self.transform:
            chunk = self.transform(chunk)
        self.consumer.write(chunk)
        # Slice rather than index: a transform may legitimately return an
        # empty string, for which chunk[-1] would raise IndexError.
        self.lastSent = chunk[-1:]

    def pauseProducing(self):
        """
        Do nothing; data is only produced in response to resumeProducing.
        """
        pass

    def stopProducing(self):
        """
        Fail the pending deferred, if any, because the consumer asked for
        production to stop.
        """
        if self.deferred:
            self.deferred.errback(Exception("Consumer asked us to stop producing"))
            self.deferred = None
Note: See TracBrowser for help on using the browser.