[Twisted-Python] SMTP

Anders Hammarquist iko at strakt.com
Thu Oct 31 13:36:17 MST 2002


SMTP server patch, take 3

This now includes a tested unit test.

Feedback is welcome, as always.

/Anders

-------------- next part --------------
Index: twisted/protocols/basic.py
===================================================================
RCS file: /cvs/Twisted/twisted/protocols/basic.py,v
retrieving revision 1.25
diff -u -u -r1.25 basic.py
--- twisted/protocols/basic.py	18 Oct 2002 06:46:52 -0000	1.25
+++ twisted/protocols/basic.py	31 Oct 2002 20:25:03 -0000
@@ -153,12 +153,15 @@
                 line, self.__buffer = self.__buffer.split(self.delimiter, 1)
             except ValueError:
                 if len(self.__buffer) > self.MAX_LENGTH:
-                    self.transport.loseConnection()
+                    line, self.__buffer = self.__buffer, ''
+                    self.lineLengthExceeded(line)
                     return
                 break
             else:
-                if len(line) > self.MAX_LENGTH:
-                    self.transport.loseConnection()
+                if len(line) > self.MAX_LENGTH:
+                    # Pass on the long line plus everything buffered after it
+                    line, self.__buffer = line + self.delimiter + self.__buffer, ''
+                    self.lineLengthExceeded(line)
                     return
                 self.lineReceived(line)
                 if self.transport.disconnecting:
@@ -200,6 +203,12 @@
         """Sends a line to the other end of the connection.
         """
         self.transport.write(line + self.delimiter)
+
+    def lineLengthExceeded(self, line):
+        """Called when a line, or undelimited buffered data, exceeds
+        MAX_LENGTH. The default drops the connection; override to handle it.
+        """
+        self.transport.loseConnection()
 

 class Int32StringReceiver(protocol.Protocol):
Index: twisted/protocols/smtp.py
===================================================================
RCS file: /cvs/Twisted/twisted/protocols/smtp.py,v
retrieving revision 1.26
diff -u -u -r1.26 smtp.py
--- twisted/protocols/smtp.py	14 Oct 2002 13:21:06 -0000	1.26
+++ twisted/protocols/smtp.py	31 Oct 2002 20:25:04 -0000
@@ -18,10 +18,10 @@
 """
 
 from twisted.protocols import basic
-from twisted.internet import protocol, defer
+from twisted.internet import protocol, defer, reactor
 from twisted.python import log
 
-import os, time, string, operator
+import os, time, string, operator, re
 
 class SMTPError(Exception):
     pass
@@ -49,19 +49,110 @@
         self.deferred.errback(arg)
         self.done = 1
 
+class AddressError(SMTPError):
+    "Parse error in address"
+    pass
+
+# Atom characters (RFC 2821), written as a regex character-class body
+atom = r"-A-Za-z0-9!#$%&'*+/=?^_`{|}~"  # has ranges: use inside [...], not with 'in'
+
+class Address:
+    """Parse and hold an RFC 2821 address; raises AddressError if malformed.
+
+    Source routes are stripped and ignored; UUCP-style bang-paths
+    and %-style routing are not parsed.
+    """
+
+    tstring = re.compile(r'''( # A string of
+                          (?:"[^"]*" # quoted string
+                          |\\. # backslash-escaped character
+                          |[''' + string.replace(atom,'#',r'\#')
+                          + r'''] # atom character
+                          )+|.) # or any single character''',re.X)
+
+    def __init__(self, addr):
+        self.local = ''
+        self.domain = ''
+        self.addrstr = addr
+
+        # Tokenize; lone one-char tokens must be atom characters or '.'
+        atl = filter(None,self.tstring.split(addr))
+
+        local = []
+        domain = []
+
+        while atl:
+            if atl[0] == '<':
+                if atl[-1] != '>':
+                    raise AddressError, "Unbalanced <>"
+                atl = atl[1:-1]
+            elif atl[0] == '@':
+                atl = atl[1:]
+                if not local:
+                    # Source route
+                    while atl and atl[0] != ':':
+                        # remove it
+                        atl = atl[1:]
+                    if not atl:
+                        raise AddressError, "Malformed source route"
+                    atl = atl[1:] # remove :
+                elif domain:
+                    raise AddressError, "Too many @"
+                else:
+                    # Now in domain
+                    domain = ['']
+            elif len(atl[0]) == 1 and not re.match('[' + atom + '.]', atl[0]):
+                raise AddressError, "Parse error at " + atl[0]
+            else:
+                if not domain:
+                    local.append(atl[0])
+                else:
+                    domain.append(atl[0])
+                atl = atl[1:]
+
+        self.local = ''.join(local)
+        self.domain = ''.join(domain)
+
+    dequotebs = re.compile(r'\\(.)')
+    def dequote(self,addr):
+        "Remove RFC-2821 quotes from address"
+        res = []
+
+        atl = filter(None,self.tstring.split(str(addr)))
+
+        for t in atl:
+            if t[0] == '"' and t[-1] == '"':
+                res.append(t[1:-1])
+            elif '\\' in t:
+                res.append(self.dequotebs.sub(r'\1',t))
+            else:
+                res.append(t)
+
+        return ''.join(res)
+
+    def __str__(self):
+        return '%s%s' % (self.local, self.domain and ("@" + self.domain) or "")
+
+    def __repr__(self):
+        return "%s.%s(%s)" % (self.__module__, self.__class__.__name__,
+                              repr(str(self)))
 
 class User:
 
     def __init__(self, destination, helo, protocol, orig):
-        try:
-            self.name, self.domain = string.split(destination, '@', 1)
-        except ValueError:
-            self.name = destination
-            self.domain = ''
+        self.dest = Address(destination)  # may raise AddressError; replaces .name/.domain
         self.helo = helo
         self.protocol = protocol
         self.orig = orig
+
+    def __getstate__(self):
+        return { 'dest' : self.dest,
+                 'helo' : self.helo,
+                 'protocol' : None,  # the protocol object is not pickled
+                 'orig' : self.orig }
+
+    def __str__(self):
+        return str(self.dest)
 
 class IMessage:
 
@@ -87,19 +178,39 @@
         self.mode = COMMAND
         self.__from = None
         self.__helo = None
-        self.__to = ()
+        self.__to = []
+
+    def timedout(self):
+        self.sendCode(421, '%s Timeout. Try talking faster next time!' %
+                      self.host)
+        self.transport.loseConnection()
 
     def connectionMade(self):
-        self.sendCode(220, 'Spammers beware, your ass is on fire')
+        self.host = self.factory.domain  # used in replies and Received: headers
+        if hasattr(self.factory, 'timeout'):
+            self.timeout = self.factory.timeout
+        else:
+            self.timeout = 600  # seconds of idle time; 0 disables the timer
+        self.sendCode(220, '%s Spammers beware, your ass is on fire' %
+                      self.host)
+        if self.timeout:
+            self.timeoutID = reactor.callLater(self.timeout, self.timedout)
 
     def sendCode(self, code, message=''):
         "Send an SMTP code with a message."
         self.transport.write('%d %s\r\n' % (code, message))
 
     def lineReceived(self, line):
+        if self.timeout:  # restart the idle timer on every line
+            self.timeoutID.cancel()  # NOTE(review): may raise if timedout already ran
+            self.timeoutID = reactor.callLater(self.timeout, self.timedout)
+
         if self.mode is DATA:
             return self.dataLineReceived(line)
-        command = string.split(line, None, 1)[0]
+        if string.strip(line):  # whitespace-only lines have no command word
+            command = string.split(line, None, 1)[0]
+        else:
+            command = ''
         method = getattr(self, 'do_'+string.upper(command), None)
         if method is None:
             method = self.do_UNKNOWN
@@ -107,59 +218,128 @@
             line = line[len(command):]
         return method(string.strip(line))
 
+    def lineLengthExceeded(self, line):
+        if self.mode is DATA:
+            for message in self.__messages:
+                message.connectionLost()
+            self.mode = COMMAND
+            del self.__messages
+        self.sendCode(500, 'Line too long')
+
+    def rawDataReceived(self, data):  # NOTE(review): nothing shown calls setRawMode()
+        "Discard data up to the next CRLF, then resume line mode."
+        rest = string.split(data, '\r\n', 1)
+        if len(rest) == 2:
+            self.setLineMode(rest[1])
+
     def do_UNKNOWN(self, rest):
-        self.sendCode(502, 'Command not implemented')
+        self.sendCode(500, 'Command not implemented')
 
     def do_HELO(self, rest):
-        self.__helo = rest
-        self.sendCode(250, 'Nice to meet you')
+        peer = self.transport.getPeer()[1]
+        self.__helo = (rest, peer)
+        self.sendCode(250, '%s Hello %s, nice to meet you' % (self.host, peer))
 
     def do_QUIT(self, rest):
         self.sendCode(221, 'See you later')
         self.transport.loseConnection()
 
+    # One or more quoted strings, backslash-escaped characters or
+    # atom characters plus '@.,:' (enough for source-routed paths)
+    qstring = r'("[^"]*"|\\.|[' + string.replace(atom,'#',r'\#') + r'@.,:])+'
+
+    mail_re = re.compile(r'''\s*FROM:\s*(?P<path><> # Empty <>
+                         |<''' + qstring + r'''> # <addr>
+                         |''' + qstring + r''' # addr
+                         )\s*(\s(?P<opts>.*))? # Optional WS + ESMTP options
+                         $''',re.I|re.X)
+    rcpt_re = re.compile(r'\s*TO:\s*(?P<path><' + qstring + r'''> # <addr>
+                         |''' + qstring + r''' # addr
+                         )\s*(\s(?P<opts>.*))? # Optional WS + ESMTP options
+                         $''',re.I|re.X)
+
     def do_MAIL(self, rest):
-        from_ = rest[len("MAIL:<"):-len(">")]
-        self.validateFrom(self.__helo, from_, self._fromValid,
-                                              self._fromInvalid)
+        if self.__from:
+            self.sendCode(503,"Only one sender per message, please")
+            return
+        # Clear old recipient list
+        self.__to = []
+        m = self.mail_re.match(rest)
+        if not m:
+            self.sendCode(501, "Syntax error")
+            return
 
-    def _fromValid(self, from_):
+        try:
+            addr = Address(m.group('path'))
+        except AddressError, e:
+            self.sendCode(553, str(e))
+            return
+
+        self.validateFrom(self.__helo, addr, self._fromValid,
+                          self._fromInvalid)
+
+    def _fromValid(self, from_, code=250, msg='From address accepted'):
         self.__from = from_
-        self.sendCode(250, 'From address accepted')
+        self.sendCode(code, msg)
 
-    def _fromInvalid(self, from_):
-        self.sendCode(550, 'No mail for you!')
+    def _fromInvalid(self, from_, code=550, msg='No mail for you!'):
+        self.sendCode(code,msg)
 
     def do_RCPT(self, rest):
-        to = rest[len("TO:<"):-len(">")]
-        user = User(to, self.__helo, self, self.__from)
-        self.validateTo(user, self._toValid, self._toInvalid)
+        if not self.__from:
+            self.sendCode(503, "Must have sender before recipient")
+            return
+        m = self.rcpt_re.match(rest)
+        if not m:
+            self.sendCode(501, "Syntax error")
+            return
 
-    def _toValid(self, to):
-        self.__to = self.__to + (to,)
-        self.sendCode(250, 'Address recognized')
+        try:
+            user = User(m.group('path'), self.__helo, self, self.__from)  # parses address; may raise AddressError
+        except AddressError, e:
+            self.sendCode(553, str(e))
+            return
+            
+        self.validateTo(user, self._toValid, self._toInvalid)
 
-    def _toInvalid(self, to):
-        self.sendCode(550, 'Cannot receive for specified address')
+    def _toValid(self, to, code=250, msg='Address recognized'):  # validateTo success callback
+        self.__to.append(to)
+        self.sendCode(code, msg)
+
+    def _toInvalid(self, to, code=550,
+                   msg='Cannot receive for specified address'):
+        self.sendCode(code, msg)
 
     def do_DATA(self, rest):
         if self.__from is None or not self.__to:  
-            self.sendCode(550, 'Must have valid receiver and originator')
+            self.sendCode(503, 'Must have valid receiver and originator')
             return
         self.mode = DATA
         helo, origin, recipients = self.__helo, self.__from, self.__to
         self.__from = None
-        self.__to = ()
+        self.__to = []
         self.__messages = self.startMessage(recipients)
+        self.__inheader = self.__inbody = 0  # header/body tracking for dataLineReceived
+        for message in self.__messages:  # prepend a Received: header to every message
+            message.lineReceived(self.receivedHeader(helo, origin, recipients))
         self.sendCode(354, 'Continue')
 
     def connectionLost(self, reason):
+        # RFC-2821 asks us to try to announce (421) that we are going away,
+        # but sending a reply code from here reportedly has no effect.
+        # NOTE(review): the idle timer (timeoutID) is not cancelled here
+        # and may still fire after the connection is gone -- confirm.
         if self.mode is DATA:
-            for message in self.__messages:
-                message.connectionLost()
+            try:
+                for message in self.__messages:
+                    message.connectionLost()
+                del self.__messages
+            except AttributeError:  # __messages may be unset or already deleted
+                pass  # NOTE(review): also hides AttributeErrors from message.connectionLost()
 
     def do_RSET(self, rest):
-        self.__init__()
+        self.__from = None  # reset the transaction only; HELO state is kept
+        self.__to = []
         self.sendCode(250, 'I remember nothing.')
 
     def dataLineReceived(self, line):
@@ -177,8 +357,24 @@
                     deferred = message.eomReceived()
                     deferred.addCallback(ndeferred.callback)
                     deferred.addErrback(ndeferred.errback)
+                del self.__messages
                 return
             line = line[1:]
+
+        # Add a blank line between the generated Received:-header
+        # and the message body if the message comes in without any
+        # headers (a first line containing ':' is taken as a header)
+        if not self.__inheader and not self.__inbody:
+            if ':' in line:
+                self.__inheader = 1
+            elif line:
+                for message in self.__messages:
+                    message.lineReceived('')
+                self.__inbody = 1
+
+        if not line:
+            self.__inbody = 1
+
         for message in self.__messages:
             message.lineReceived(line)
 
@@ -188,8 +384,27 @@
     def _messageNotHandled(self, _):
         self.sendCode(550, 'Could not send e-mail')
 
+    def rfc822date(self):
+        "Return the current local time as an RFC 822 date string."
+        timeinfo = time.localtime()
+        if timeinfo[8]:
+            tz = -time.altzone  # DST in effect
+        else:
+            tz = -time.timezone
+        sign, tz = (tz < 0 and '-' or '+'), abs(tz)  # floor division broke -0330
+        return "%s %s%2.2d%2.2d" % (
+            time.strftime("%a, %d %b %Y %H:%M:%S", timeinfo),
+            sign, tz / 3600, (tz / 60) % 60)
+
     # overridable methods:
+    def receivedHeader(self, helo, origin, recipents):
+        return "Received: From %s ([%s]) by %s; %s" % (
+            helo[0], helo[1], self.host, self.rfc822date())
+
     def validateFrom(self, helo, origin, success, failure):
+        if not helo:
+            failure(origin, 503, "Who are you? Say HELO first")
+            return
         success(origin)
 
     def validateTo(self, user, success, failure):
@@ -265,6 +480,7 @@
         
     def smtpCode_354_data(self, line):
         self.mailFile = self.getMailData()
+        self.lastsent = ''  # last byte written to the transport; '' before any chunk
         self.transport.registerProducer(self, 0)
 
     def smtpCode_250_afterData(self, line):
@@ -277,11 +493,18 @@
         chunk = self.mailFile.read(8192)
         if not chunk:
             self.transport.unregisterProducer()
-            self.sendLine('.')
+            if self.lastsent not in ('', '\n'):  # terminate an unfinished last line
+                self.sendLine('')
+            self.sendLine('.')
             self.state = 'afterData'
+            return
 
         chunk = string.replace(chunk, "\n", "\r\n")
+        if chunk[:1] == '.' and self.lastsent in ('', '\n'):  # chunk starts a line
+            chunk = '.' + chunk  # dot-stuff at message start and chunk boundaries
+        chunk = string.replace(chunk, "\r\n.", "\r\n..")
         self.transport.write(chunk)
+        self.lastsent = chunk[-1]
 
     def pauseProducing(self):
         pass
Index: twisted/test/test_protocols.py
===================================================================
RCS file: /cvs/Twisted/twisted/test/test_protocols.py,v
retrieving revision 1.17
diff -u -u -r1.17 test_protocols.py
--- twisted/test/test_protocols.py	23 Sep 2002 08:51:29 -0000	1.17
+++ twisted/test/test_protocols.py	31 Oct 2002 20:25:04 -0000
@@ -33,11 +33,13 @@
 class LineTester(basic.LineReceiver):
 
     delimiter = '\n'
+    MAX_LENGTH = 64  # small, so the test data can exceed it
 
     def connectionMade(self):
         self.received = []
 
     def lineReceived(self, line):
+        assert len(line) <= self.MAX_LENGTH, "over-long line delivered"
         self.received.append(line)
         if line == '':
             self.setRawMode()
@@ -51,6 +53,10 @@
         if self.length == 0:
             self.setLineMode(rest)
 
+    def lineLengthExceeded(self, line):
+        if len(line) > self.MAX_LENGTH+1:
+            self.setLineMode(line[self.MAX_LENGTH+1:])  # drop the first MAX_LENGTH+1 chars, re-feed the rest
+
 class WireTestCase(unittest.TestCase):
 
     def testEcho(self):
@@ -103,13 +109,14 @@
 012345678len 0
 foo 5
 
+1234567890123456789012345678901234567890123456789012345678901234567890
 len 1
 
 a'''
 
     output = ['len 10', '0123456789', 'len 5', '1234\n',
               'len 20', 'foo 123', '0123456789\n012345678',
-              'len 0', 'foo 5', '', 'len 1', 'a']
+              'len 0', 'foo 5', '', '67890', 'len 1', 'a']  # '67890': tail of the over-long line
 
     def testBuffer(self):
         for packet_size in range(1, 10):
@@ -175,3 +182,6 @@
             r.dataReceived(s)
             if not r.brokenPeer:
                 raise AssertionError("connection wasn't closed on illegal netstring %s" % repr(s))
+
+if __name__ == '__main__':  # NOTE(review): assumes this unittest module provides main()
+    unittest.main()
Index: twisted/test/test_smtp.py
===================================================================
RCS file: /cvs/Twisted/twisted/test/test_smtp.py,v
retrieving revision 1.16
diff -u -u -r1.16 test_smtp.py
--- twisted/test/test_smtp.py	6 Sep 2002 09:09:42 -0000	1.16
+++ twisted/test/test_smtp.py	31 Oct 2002 20:25:04 -0000
@@ -26,7 +26,7 @@
 from twisted.protocols import loopback, smtp
 from twisted.internet import defer, protocol
 from twisted.test.test_protocols import StringIOWithoutClosing
-import string
+import string, re
 from cStringIO import StringIO
 
 class DummyMessage:
@@ -195,15 +195,17 @@
         self.buffer = []
 
     def lineReceived(self, line):
-        self.buffer.append(line)
+        # Drop the generated Received: header (raw string: regex escapes)
+        if not re.match(r'Received: From foo.com \(\[.*\]\) by foo.com;', line):
+            self.buffer.append(line)
 
     def eomReceived(self):
         message = string.join(self.buffer, '\n')+'\n'
-        helo, origin = self.users[0].helo, self.users[0].orig
+        helo, origin = self.users[0].helo[0], str(self.users[0].orig)
         recipients = []
         for user in self.users:
-            recipients.append(user.name+'@'+user.domain)
-        self.protocol.messages.append((helo, origin, recipients, message))
+            recipients.append(str(user))
+        self.protocol.message = (helo, origin, recipients, message)  # only the latest is kept
         deferred = defer.Deferred()
         deferred.callback("saved")
         return deferred
@@ -212,51 +214,84 @@
 
     def connectionMade(self):
         smtp.SMTP.connectionMade(self)
-        self.messages = []
+        self.message = None  # set by DummySMTPMessage.eomReceived
 
     def startMessage(self, users):
         return [DummySMTPMessage(self, users)]
 
-
 class AnotherSMTPTestCase(unittest.TestCase):
 
-    messages = [ ('foo.com', 'moshez@foo.com', ['moshez@bar.com'], '''\
+    messages = [ ('foo.com', 'moshez@foo.com', ['moshez@bar.com'],
+                  'moshez@foo.com', ['moshez@bar.com'], '''\
 From: Moshe
 To: Moshe
 
 Hi,
 how are you?
 '''),
-                 ('foo.com', 'tttt@rrr.com', ['uuu@ooo', 'yyy@eee'], '''\
+                 ('foo.com', 'tttt@rrr.com', ['uuu@ooo', 'yyy@eee'],
+                  'tttt@rrr.com', ['uuu@ooo', 'yyy@eee'], '''\
 Subject: pass
 
 ..rrrr..
-''')
+'''),
+                 ('foo.com', '@this,@is,@ignored:foo@bar.com',
+                  ['@ignore,@this,@too:bar@foo.com'],
+                  'foo@bar.com', ['bar@foo.com'], '''\
+Subject: apa
+To: foo
+
+123
+.
+456
+'''),
               ]
 
-    expected_output = '220 Spammers beware, your ass is on fire\015\012250 Nice to meet you\015\012250 From address accepted\015\012250 Address recognized\015\012354 Continue\015\012250 Delivery in progress\015\012250 From address accepted\015\012250 Address recognized\015\012250 Address recognized\015\012354 Continue\015\012250 Delivery in progress\015\012221 See you later\015\012'
-
-    input = 'HELO foo.com\r\n'
-    for _, from_, to_, message in messages:
-        input = input + 'MAIL FROM:<%s>\r\n' % from_
-        for item in to_:
-            input = input + 'RCPT TO:<%s>\r\n' % item
-        input = input + 'DATA\r\n'
-        for line in string.split(message, '\n')[:-1]:
-            if line[:1] == '.': line = '.' + line
-            input = input + line + '\r\n'
-        input = input + '.' + '\r\n'
-    input = input + 'QUIT\r\n'
+    testdata = [
+        ('', '220.*\r\n$', None, None),
+        ('HELO foo.com\r\n', '250.*\r\n$', None, None),
+        ('RSET\r\n', '250.*\r\n$', None, None),
+        ]
+    for helo_, from_, to_, realfrom, realto, msg in messages:
+        testdata.append(('MAIL FROM:<%s>\r\n' % from_, '250.*\r\n',
+                         None, None))
+        for rcpt in to_:
+            testdata.append(('RCPT TO:<%s>\r\n' % rcpt, '250.*\r\n',
+                             None, None))
+        testdata.append(('DATA\r\n','354.*\r\n',
+                         msg, ('250.*\r\n',
+                               (helo_, realfrom, realto, msg))))
+
 
     def testBuffer(self):
         output = StringIOWithoutClosing()
         a = DummySMTP()
-        a.makeConnection(protocol.FileWrapper(output))
-        a.dataReceived(self.input)
-        if a.messages != self.messages:
-            raise AssertionError(a.messages)
-        if output.getvalue() != self.expected_output:
-            raise AssertionError(`output.getvalue()`)
+        class fooFactory:
+            domain = 'foo.com'
 
+        a.factory = fooFactory()
+        a.makeConnection(protocol.FileWrapper(output))
+        for (send, expect, msg, msgexpect) in self.testdata:
+            if send:
+                a.dataReceived(send)
+            data = output.getvalue()
+            output.truncate(0)
+            if not re.match(expect, data):
+                raise AssertionError, (send, expect, data)
+            if data[:3] == '354':
+                for line in msg.splitlines():
+                    if line and line[0] == '.':
+                        line = '.' + line
+                    a.dataReceived(line + '\r\n')
+                a.dataReceived('.\r\n')
+                # Special case for DATA. Now we want a 250, and then
+                # we compare the messages
+                data = output.getvalue()
+                output.truncate(0)  # truncate() alone would keep the 250 around
+                resp, msgdata = msgexpect
+                if not re.match(resp, data):
+                    raise AssertionError, (resp, data)
+                if a.message != msgdata:
+                    raise AssertionError, (msgdata, a.message)
 
 testCases = [SMTPTestCase, SMTPClientTestCase, LoopbackSMTPTestCase, AnotherSMTPTestCase]


More information about the Twisted-Python mailing list