[Twisted-Python] SMTP patch take 2

Anders Hammarquist iko at strakt.com
Mon Oct 7 09:18:27 EDT 2002


Hello,

Here is the promised revised SMTP patch. This adds a unit test for the
change in LineReceiver (the test can probably be improved upon, since
lineLengthExceeded is more of an exception), and a test to make sure
the '.' that ends the mail data is on its own line in SMTPClient. Plus
misc. cleanups. Comments are welcome as always.

/Anders

-------------- next part --------------
Index: twisted/protocols/basic.py
===================================================================
RCS file: /cvs/Twisted/twisted/protocols/basic.py,v
retrieving revision 1.24
diff -u -u -r1.24 basic.py
--- twisted/protocols/basic.py	1 Oct 2002 15:09:25 -0000	1.24
+++ twisted/protocols/basic.py	7 Oct 2002 13:04:01 -0000
@@ -153,12 +153,15 @@
                 line, self.__buffer = self.__buffer.split(self.delimiter, 1)
             except ValueError:
                 if len(self.__buffer) > self.MAX_LENGTH:
-                    self.transport.loseConnection()
+                    line, self.__buffer = self.__buffer, ''
+                    self.lineLengthExceeded(line)
                     return
                 break
             else:
-                if len(line) > self.MAX_LENGTH:
-                    self.transport.loseConnection()
+                if len(line) > self.MAX_LENGTH:
+                    # hand over everything unprocessed, as the branch above does
+                    exceeded, self.__buffer = line + self.delimiter + self.__buffer, ''
+                    self.lineLengthExceeded(exceeded)
                     return
                 self.lineReceived(line)
                 if self.transport.disconnecting:
@@ -200,6 +203,12 @@
         """Sends a line to the other end of the connection.
         """
         self.transport.write(line + self.delimiter)
+
+    def lineLengthExceeded(self, line):
+        """Called when the maximum line length has been reached.
+        Override if it needs to be dealt with in some special way.
+        """
+        self.transport.loseConnection()
 
 
 class Int32StringReceiver(protocol.Protocol):
Index: twisted/protocols/smtp.py
===================================================================
RCS file: /cvs/Twisted/twisted/protocols/smtp.py,v
retrieving revision 1.24
diff -u -u -r1.24 smtp.py
--- twisted/protocols/smtp.py	2 Oct 2002 12:38:24 -0000	1.24
+++ twisted/protocols/smtp.py	7 Oct 2002 13:04:01 -0000
@@ -18,10 +18,10 @@
 """
 
 from twisted.protocols import basic
-from twisted.internet import protocol, defer
+from twisted.internet import protocol, defer, reactor
 from twisted.python import log
 
-import os, time, string, operator
+import os, time, string, operator, re
 
 class SMTPError(Exception):
     pass
@@ -49,20 +49,104 @@
         self.deferred.errback(arg)
         self.done = 1
 
+class AddressError(SMTPError):
+    "Parse error in address"
+    pass
+
+# Character classes for parsing addresses
+atom = r"-A-Za-z0-9!#$%&'*+/=?^_`{|}~"
+
+class Address:
+    """Parse and hold an RFC 2821 address.
+
+    Source routes are stripped and ignored, UUCP-style bang-paths
+    and %-style routing are not parsed.
+    """
+
+    tstring = re.compile(r'''( # A string of
+                          (?:"[^"]*" # quoted string
+                          |\\. # backslash-escaped character
+                          |[''' + string.replace(atom,'#',r'\#')
+                          + r'''] # atom character
+                          )+|.) # or any single character''',re.X)
+    # One character that may appear unquoted: an atom character or '.'
+    atomchar = re.compile('[' + atom + '.]')
+
+    def __init__(self, addr):
+        self.local = ''
+        self.domain = ''
+        self.addrstr = addr
+
+        # Tokenize
+        atl = filter(None,self.tstring.split(addr))
+
+        local = []
+        domain = []
+
+        while atl:
+            if atl[0] == '<':
+                if atl[-1] != '>':
+                    raise AddressError, "Unbalanced <>"
+                atl = atl[1:-1]
+            elif atl[0] == '@':
+                atl = atl[1:]
+                if not local:
+                    # Source route
+                    while atl and atl[0] != ':':
+                        # remove it
+                        atl = atl[1:]
+                    if not atl:
+                        raise AddressError, "Malformed source route"
+                    atl = atl[1:] # remove :
+                elif domain:
+                    raise AddressError, "Too many @"
+                else:
+                    # Now in domain
+                    domain = ['']
+            elif len(atl[0]) == 1 and not self.atomchar.match(atl[0]):
+                raise AddressError, "Parse error at " + atl[0]
+            else:
+                if not domain:
+                    local.append(atl[0])
+                else:
+                    domain.append(atl[0])
+                atl = atl[1:]
+
+        self.local = ''.join(local)
+        self.domain = ''.join(domain)
+
+    dequotebs = re.compile(r'\\(.)')
+    def dequote(self, addr):
+        "Remove RFC-2821 quotes from address"
+        res = []
+
+        atl = filter(None,self.tstring.split(addr))
+
+        for t in atl:
+            if t[0] == '"' and t[-1] == '"':
+                res.append(t[1:-1])
+            elif '\\' in t:
+                res.append(self.dequotebs.sub(r'\1',t))
+            else:
+                res.append(t)
+
+        return ''.join(res)
+
+    def __str__(self):
+        return '%s%s' % (self.local, self.domain and ("@" + self.domain) or "")
+
+    def __repr__(self):
+        return "%s.%s(%s)" % (self.__module__, self.__class__.__name__,
+                              repr(str(self)))
 
 class User:
 
     def __init__(self, destination, helo, protocol, orig):
-        try:
-            self.name, self.domain = string.split(destination, '@', 1)
-        except ValueError:
-            self.name = destination
-            self.domain = ''
+        self.dest = Address(destination)
         self.helo = helo
         self.protocol = protocol
         self.orig = orig
 
-
 class IMessage:
 
     def lineReceived(self, line):
@@ -83,23 +167,40 @@
 
 class SMTP(basic.LineReceiver):
 
-    def __init__(self):
+    def __init__(self, domain, timeout=600):
         self.mode = COMMAND
         self.__from = None
         self.__helo = None
-        self.__to = ()
+        self.__to = []
+        self.timeout = timeout
+        self.host = domain
+
+    def timedout(self):
+        self.sendCode(421, '%s Timeout. Try talking faster next time!' %
+                      self.host)
+        self.transport.loseConnection()
 
     def connectionMade(self):
-        self.sendCode(220, 'Spammers beware, your ass is on fire')
+        self.sendCode(220, '%s Spammers beware, your ass is on fire' %
+                      self.host)
+        if self.timeout:
+            self.timeoutID = reactor.callLater(self.timeout, self.timedout)
 
     def sendCode(self, code, message=''):
         "Send an SMTP code with a message."
         self.transport.write('%d %s\r\n' % (code, message))
 
     def lineReceived(self, line):
+        if self.timeout:
+            self.timeoutID.cancel()
+            self.timeoutID = reactor.callLater(self.timeout, self.timedout)
+
         if self.mode is DATA:
             return self.dataLineReceived(line)
-        command = string.split(line, None, 1)[0]
+        if line:
+            command = string.split(line, None, 1)[0]
+        else:
+            command = ''
         method = getattr(self, 'do_'+string.upper(command), None)
         if method is None:
             method = self.do_UNKNOWN
@@ -107,21 +208,65 @@
             line = line[len(command):]
         return method(string.strip(line))
 
+    def lineLengthExceeded(self, line):
+        if self.mode is DATA:
+            for message in self.__messages:
+                message.connectionLost()
+            self.mode = COMMAND
+            del self.__messages
+        self.sendCode(500, 'Line too long')
+        self.setRawMode()
+        self.rawDataReceived(line)
+
+    def rawDataReceived(self, data):
+        "Throw away rest of long line (up to and including the next CRLF)"
+        if '\r\n' in data: self.setLineMode(string.split(data, '\r\n', 1)[1])
+
     def do_UNKNOWN(self, rest):
-        self.sendCode(502, 'Command not implemented')
+        self.sendCode(500, 'Command not implemented')
 
     def do_HELO(self, rest):
-        self.__helo = rest
-        self.sendCode(250, 'Nice to meet you')
+        peer = self.transport.getPeer()[1]
+        self.__helo = (rest, peer)
+        self.sendCode(250, '%s Hello %s, nice to meet you' % (self.host, peer))
 
     def do_QUIT(self, rest):
         self.sendCode(221, 'See you later')
         self.transport.loseConnection()
 
+    # A string of quoted strings, backslash-escaped character or
+    # atom characters + '@.,:'
+    qstring = r'("[^"]*"|\\.|[' + string.replace(atom,'#',r'\#') + r'@.,:])+'
+
+    mail_re = re.compile(r'''\s*FROM:\s*(?P<path><> # Empty <>
+                         |<''' + qstring + r'''> # <addr>
+			 |''' + qstring + r''' # addr
+			 )\s*(\s(?P<opts>.*))? # Optional WS + ESMTP options
+			 $''',re.I|re.X)
+    rcpt_re = re.compile(r'\s*TO:\s*(?P<path><' + qstring + r'''> # <addr>
+                         |''' + qstring + r''' # addr
+			 )\s*(\s(?P<opts>.*))? # Optional WS + ESMTP options
+			 $''',re.I|re.X)
+
     def do_MAIL(self, rest):
-        from_ = rest[len("MAIL:<"):-len(">")]
-        self.validateFrom(self.__helo, from_, self._fromValid,
-                                              self._fromInvalid)
+        if self.__from:
+            self.sendCode(503,"Only one sender per message, please")
+            return
+        # Clear old recipient list
+        self.__to = []
+        m = self.mail_re.match(rest)
+        if not m:
+            self.sendCode(501, "Syntax error")
+            return
+
+        try:
+            addr = Address(m.group('path'))
+        except AddressError, e:
+            self.sendCode(553, str(e))
+            return
+            
+        self.validateFrom(self.__helo, addr, self._fromValid,
+                          self._fromInvalid)
 
     def _fromValid(self, from_):
         self.__from = from_
@@ -131,12 +276,24 @@
         self.sendCode(550, 'No mail for you!')
 
     def do_RCPT(self, rest):
-        to = rest[len("TO:<"):-len(">")]
-        user = User(to, self.__helo, self, self.__from)
+        if not self.__from:
+            self.sendCode(503, "Must have sender before recipient")
+            return
+        m = self.rcpt_re.match(rest)
+        if not m:
+            self.sendCode(501, "Syntax error")
+            return
+
+        try:
+            user = User(m.group('path'), self.__helo, self, self.__from)
+        except AddressError, e:
+            self.sendCode(553, str(e))
+            return
+
         self.validateTo(user, self._toValid, self._toInvalid)
 
     def _toValid(self, to):
-        self.__to = self.__to + (to,)
+        self.__to.append(to)
         self.sendCode(250, 'Address recognized')
 
     def _toInvalid(self, to):
@@ -144,22 +301,30 @@
 
     def do_DATA(self, rest):
         if self.__from is None or not self.__to:  
-            self.sendCode(550, 'Must have valid receiver and originator')
+            self.sendCode(503, 'Must have valid receiver and originator')
             return
         self.mode = DATA
         helo, origin, recipients = self.__helo, self.__from, self.__to
         self.__from = None
-        self.__to = ()
+        self.__to = []
         self.__messages = self.startMessage(recipients)
+        for message in self.__messages:
+            message.lineReceived(self.receivedHeader(helo, origin, recipients))
         self.sendCode(354, 'Continue')
 
     def connectionLost(self, reason):
+        # self.sendCode(421, 'Dropping connection.') # This does nothing...
+        # Ideally, if we (rather than the other side) lose the connection,
+        # we should be able to tell the other side that we are going away.
+        # RFC-2821 requires that we try.
         if self.mode is DATA:
             for message in self.__messages:
                 message.connectionLost()
+            del self.__messages
 
     def do_RSET(self, rest):
-        self.__init__()
+        self.__from = None
+        self.__to = []
         self.sendCode(250, 'I remember nothing.')
 
     def dataLineReceived(self, line):
@@ -177,6 +342,7 @@
                     deferred = message.eomReceived()
                     deferred.addCallback(ndeferred.callback)
                     deferred.addErrback(ndeferred.errback)
+                del self.__messages
                 return
             line = line[1:]
         for message in self.__messages:
@@ -189,7 +355,14 @@
         self.sendCode(550, 'Could not send e-mail')
 
     # overridable methods:
+    def receivedHeader(self, helo, origin, recipients):
+        return "Received: From %s ([%s]) by %s; %s" % (
+            helo[0], helo[1], self.host,
+            time.strftime("%a, %d %b %Y %H:%M:%S +0000", time.gmtime()))
+
     def validateFrom(self, helo, origin, success, failure):
+        if not helo:
+            return self.sendCode(503,"Who are you? Say HELO first")
         success(origin)
 
     def validateTo(self, user, success, failure):
@@ -265,6 +438,7 @@
         
     def smtpCode_354_data(self, line):
         self.mailFile = self.getMailData()
+        self.lastsent = ''
         self.transport.registerProducer(self, 0)
 
     def smtpCode_250_afterData(self, line):
@@ -277,11 +451,18 @@
         chunk = self.mailFile.read(8192)
         if not chunk:
             self.transport.unregisterProducer()
-            self.sendLine('.')
+            if self.lastsent != '\n':
+                line = '\r\n.'
+            else:
+                line = '.'
+            self.sendLine(line)
             self.state = 'afterData'
+            return
 
         chunk = string.replace(chunk, "\n", "\r\n")
+        chunk = string.replace((self.lastsent or '\n') + chunk, "\n.", "\n..")[1:]
         self.transport.write(chunk)
+        self.lastsent = chunk[-1]
 
     def pauseProducing(self):
         pass
Index: twisted/test/test_protocols.py
===================================================================
RCS file: /cvs/Twisted/twisted/test/test_protocols.py,v
retrieving revision 1.17
diff -u -u -r1.17 test_protocols.py
--- twisted/test/test_protocols.py	23 Sep 2002 08:51:29 -0000	1.17
+++ twisted/test/test_protocols.py	7 Oct 2002 13:04:02 -0000
@@ -33,11 +33,13 @@
 class LineTester(basic.LineReceiver):
 
     delimiter = '\n'
+    MAX_LENGTH = 64
 
     def connectionMade(self):
         self.received = []
 
     def lineReceived(self, line):
+        # Record every line; an empty line switches to raw mode.
         self.received.append(line)
         if line == '':
             self.setRawMode()
@@ -51,6 +53,10 @@
         if self.length == 0:
             self.setLineMode(rest)
 
+    def lineLengthExceeded(self, line):
+        if len(line) > self.MAX_LENGTH+1:
+            self.setLineMode(line[self.MAX_LENGTH+1:])
+
 class WireTestCase(unittest.TestCase):
 
     def testEcho(self):
@@ -103,13 +109,14 @@
 012345678len 0
 foo 5
 
+1234567890123456789012345678901234567890123456789012345678901234567890
 len 1
 
 a'''
 
     output = ['len 10', '0123456789', 'len 5', '1234\n',
               'len 20', 'foo 123', '0123456789\n012345678',
-              'len 0', 'foo 5', '', 'len 1', 'a']
+              'len 0', 'foo 5', '', '67890', 'len 1', 'a']
 
     def testBuffer(self):
         for packet_size in range(1, 10):
@@ -175,3 +182,6 @@
             r.dataReceived(s)
             if not r.brokenPeer:
                 raise AssertionError("connection wasn't closed on illegal netstring %s" % repr(s))
+
+if __name__ == '__main__':
+    unittest.main()
-------------- next part --------------
-- 
 -- Of course I'm crazy, but that doesn't mean I'm wrong.
Anders Hammarquist                                  | iko at strakt.com
AB Strakt                                           | Tel: +46 31 749 08 80
Göteborg, Sweden.           RADIO: SM6XMM and N2JGL | Fax: +46 31 749 08 81


More information about the Twisted-Python mailing list