[Twisted-Python] Improvements to twisted.protocols.smtp

Anders Hammarquist iko at strakt.com
Mon Sep 30 11:01:36 EDT 2002


Hi!

We need an SMTP server and found the current implementation a bit
fragile. This patch should make it more robust and also improve RFC 2821
compliance (no guarantees, though).

I didn't touch the SMTP client (apart from adding double-dot escaping of
lines that begin with a dot), but I expect I'll look it over as well.

Comments etc. are welcome. The patch is relative to Twisted CVS as of a
few minutes ago.

Regards,
/Anders

-- 
 -- Of course I'm crazy, but that doesn't mean I'm wrong.
Anders Hammarquist                                  | iko at strakt.com
AB Strakt                                           | Tel: +46 31 711 08 70
Göteborg, Sweden.           RADIO: SM6XMM and N2JGL | Fax: +46 31 711 08 80

-------------- next part --------------
Index: twisted/protocols/basic.py
===================================================================
RCS file: /cvs/Twisted/twisted/protocols/basic.py,v
retrieving revision 1.23
diff -u -u -r1.23 basic.py
--- twisted/protocols/basic.py	23 Sep 2002 08:51:29 -0000	1.23
+++ twisted/protocols/basic.py	30 Sep 2002 14:49:24 -0000
@@ -150,12 +150,14 @@
                 line, self.__buffer = self.__buffer.split(self.delimiter, 1)
             except ValueError:
                 if len(self.__buffer) > self.MAX_LENGTH:
-                    self.transport.loseConnection()
+                    line, self.__buffer = self.__buffer, ''
+                    self.lineLengthExceeded(line)
                     return
                 break
             else:
                 if len(line) > self.MAX_LENGTH:
-                    self.transport.loseConnection()
+                    line, self.__buffer = self.__buffer, ''
+                    self.lineLengthExceeded(line)
                     return
                 self.lineReceived(line)
                 if self.transport.disconnecting:
@@ -197,6 +199,12 @@
         """Sends a line to the other end of the connection.
         """
         self.transport.write(line + self.delimiter)
+
+    def lineLengthExceeded(self, line):
+        """Called when the maximum line length has been reached.
+        Override if it needs to be dealt with in some special way.
+        """
+        self.transport.loseConnection()
 
 
 class Int32StringReceiver(protocol.Protocol):
Index: twisted/protocols/smtp.py
===================================================================
RCS file: /cvs/Twisted/twisted/protocols/smtp.py,v
retrieving revision 1.23
diff -u -u -r1.23 smtp.py
--- twisted/protocols/smtp.py	19 Aug 2002 03:21:58 -0000	1.23
+++ twisted/protocols/smtp.py	30 Sep 2002 14:49:24 -0000
@@ -18,10 +18,10 @@
 """
 
 from twisted.protocols import basic
-from twisted.internet import protocol, defer
+from twisted.internet import protocol, defer, reactor
 from twisted.python import log
 
-import os, time, string, operator
+import os, time, string, operator, re
 
 class SMTPError(Exception):
     pass
@@ -49,15 +49,93 @@
         self.deferred.errback(arg)
         self.done = 1
 
+class AddressError(SMTPError):
+    "Parse error in address"
+    pass
+
+# Character classes for parsing addresses
+atom = r"-A-Za-z0-9!#$%&'*+/=?^_`{|}~"
+
+class Address:
+    """Parse and hold an RFC 2821 address.
+
+    Source routes are stipped and ignored, UUCP-style bang-paths
+    and %-style routing are not parsed.
+    """
+
+    qstring = re.compile(r'((?:"[^"]*"|\\.|[' + atom + r'])+|.)')
+
+    def __init__(self, addr):
+        self.local = ''
+        self.domain = ''
+        self.addrstr = addr
+
+        # Tokenize
+        atl = filter(None,self.qstring.split(addr))
+
+        local = []
+        domain = []
+
+        while atl:
+            if atl[0] == '<':
+                if atl[-1] != '>':
+                    raise AddressError, "Unbalanced <>"
+                atl = atl[1:-1]
+            elif atl[0] == '@':
+                atl = atl[1:]
+                if not local:
+                    # Source route
+                    while atl and atl[0] != ':':
+                        # remove it
+                        atl = atl[1:]
+                    if not atl:
+                        raise AddressError, "Malformed source route"
+                    atl = atl[1:] # remove :
+                elif domain:
+                    raise AddressError, "Too many @"
+                else:
+                    # Now in domain
+                    domain = ['']
+            elif len(atl[0]) == 1 and atl[0] not in atom + '.':
+                raise AddressError, "Parse error at " + atl[0]
+            else:
+                if not domain:
+                    local.append(atl[0])
+                else:
+                    domain.append(atl[0])
+                atl = atl[1:]
+               
+        self.local = ''.join(local)
+        self.domain = ''.join(domain)
+
+    dequotebs = re.compile(r'\\(.)')
+    def dequote(self, addr):
+        "Remove RFC-2821 quotes from address"
+        res = []
+
+        atl = filter(None,self.qstring.split(addr))
+
+        for t in atl:
+            if t[0] == '"' and t[-1] == '"':
+                res.append(t[1:-1])
+            elif '\\' in t:
+                res.append(self.dequotebs.sub(r'\1',t))
+            else:
+                res.append(t)
+
+        return ''.join(res)
+
+    def __str__(self):
+        return '%s%s' % (self.local, self.domain and ("@" + self.domain) or "")
+
+    def __repr__(self):
+        return "%s.%s(%s)" % (self.__module__, self.__class__.__name__,
+                              repr(str(self)))
 
-class User:
+class User(Address):
 
     def __init__(self, destination, helo, protocol, orig):
-        try:
-            self.name, self.domain = string.split(destination, '@', 1)
-        except ValueError:
-            self.name = destination
-            self.domain = ''
+        Address.__init__(self,destination)
         self.helo = helo
         self.protocol = protocol
         self.orig = orig
@@ -83,23 +161,41 @@
 
 class SMTP(basic.LineReceiver):
 
-    def __init__(self):
+    def __init__(self, domain=None, timeout=600):
         self.mode = COMMAND
         self.__from = None
         self.__helo = None
-        self.__to = ()
+        self.__to = []
+        self.timeout = timeout
+        if not domain:
+            import socket
+            domain = socket.getfqdn()
+        self.host = domain
+
+    def timedout(self):
+        self.sendCode(421, '%s Timeout. Try talking faster next time!' %
+                      self.host)
+        self.transport.loseConnection()
 
     def connectionMade(self):
-        self.sendCode(220, 'Spammers beware, your ass is on fire')
+        self.sendCode(220, '%s Spammers beware, your ass is on fire' %
+                      self.host)
+        self.timeoutID = reactor.callLater(self.timeout, self.timedout)
 
     def sendCode(self, code, message=''):
         "Send an SMTP code with a message."
         self.transport.write('%d %s\r\n' % (code, message))
 
     def lineReceived(self, line):
+        self.timeoutID.cancel()
+        self.timeoutID = reactor.callLater(self.timeout, self.timedout)
+
         if self.mode is DATA:
             return self.dataLineReceived(line)
-        command = string.split(line, None, 1)[0]
+        if line:
+            command = string.split(line, None, 1)[0]
+        else:
+            command = ''
         method = getattr(self, 'do_'+string.upper(command), None)
         if method is None:
             method = self.do_UNKNOWN
@@ -107,21 +203,59 @@
             line = line[len(command):]
         return method(string.strip(line))
 
+    def lineLengthExceeded(self, line):
+        if self.mode is DATA:
+            for message in self.__messages:
+                message.connectionLost()
+            self.mode = COMMAND
+            del self.__messages
+        self.sendCode(500, 'Line too long')
+
     def do_UNKNOWN(self, rest):
-        self.sendCode(502, 'Command not implemented')
+        self.sendCode(500, 'Command not implemented')
 
     def do_HELO(self, rest):
-        self.__helo = rest
-        self.sendCode(250, 'Nice to meet you')
+        peer = self.transport.getPeer()[1]
+        self.__helo = (rest, peer)
+        self.sendCode(250, '%s Hello %s, nice to meet you' % (self.host, peer))
 
     def do_QUIT(self, rest):
         self.sendCode(221, 'See you later')
         self.transport.loseConnection()
 
+    qstring = r'("[^"]*"|\\.|[' + atom + r'@.])+'
+
+    path_re = re.compile(r"(<(?=.*>))?(?P<addr><>|(?<=<)" + qstring + r"(?=(?<!\\)>)|" + qstring + r")((?<![\\<])>)?$")
+    mail_re = re.compile(r'\s*FROM:\s*(?P<path><>|<' + qstring + r'>|' +
+                         qstring + r')\s*(?P<opts>.*)$',re.I)
+    rcpt_re = re.compile(r'\s*TO:\s*(?P<path><' + qstring + r'>|' +
+                         qstring + r')\s*(?P<opts>.*)$',re.I)
+
     def do_MAIL(self, rest):
-        from_ = rest[len("MAIL:<"):-len(">")]
-        self.validateFrom(self.__helo, from_, self._fromValid,
-                                              self._fromInvalid)
+        if not self.__helo:  # no HELO yet: reply 503 and stop, do not fall through
+            return self.sendCode(503,"Who are you? Say HELO first")
+        if self.__from:
+            self.sendCode(503,"Only one sender per message, please")
+            return
+        # Clear old recipient list
+        self.__to = []
+        m = self.mail_re.match(rest)
+        if not m:
+            self.sendCode(501, "Syntax error")
+            return
+        m = self.path_re.match(m.group('path'))
+        if not m:
+            self.sendCode(553, "Unparseable address")
+            return
+
+        try:
+            addr = Address(m.group('addr'))
+        except AddressError, e:
+            self.sendCode(553, str(e))
+            return
+            
+        self.validateFrom(self.__helo, addr, self._fromValid,
+                          self._fromInvalid)
 
     def _fromValid(self, from_):
         self.__from = from_
@@ -131,12 +265,28 @@
         self.sendCode(550, 'No mail for you!')
 
     def do_RCPT(self, rest):
-        to = rest[len("TO:<"):-len(">")]
-        user = User(to, self.__helo, self, self.__from)
+        if not self.__from:
+            self.sendCode(503, "Must have sender before recpient")
+            return
+        m = self.rcpt_re.match(rest)
+        if not m:
+            self.sendCode(501, "Syntax error")
+            return
+        m = self.path_re.match(m.group('path'))
+        if not m:
+            self.sendCode(553, "Unparseable address")
+            return
+
+        try:
+            user = User(m.group('addr'), self.__helo, self, self.__from)
+        except AddressError, e:
+            self.sendCode(553, str(e))
+            return
+            
         self.validateTo(user, self._toValid, self._toInvalid)
 
     def _toValid(self, to):
-        self.__to = self.__to + (to,)
+        self.__to.append(to)
         self.sendCode(250, 'Address recognized')
 
     def _toInvalid(self, to):
@@ -144,22 +294,29 @@
 
     def do_DATA(self, rest):
         if self.__from is None or not self.__to:  
-            self.sendCode(550, 'Must have valid receiver and originator')
+            self.sendCode(503, 'Must have valid receiver and originator')
             return
         self.mode = DATA
         helo, origin, recipients = self.__helo, self.__from, self.__to
         self.__from = None
-        self.__to = ()
+        self.__to = []
         self.__messages = self.startMessage(recipients)
+        for message in self.__messages:
+            message.lineReceived(self.receivedHeader(helo, origin, recipients))
         self.sendCode(354, 'Continue')
 
     def connectionLost(self, reason):
+        # self.sendCode(421, 'Loosing connection.') # This does nothing...
+        # Ideally, if we (rather than the other side) lose the connection,
+        # we should be able to tell the other side that we are going away.
+        # RFC-2821 requires that we try.
         if self.mode is DATA:
             for message in self.__messages:
                 message.connectionLost()
 
     def do_RSET(self, rest):
-        self.__init__()
+        self.__from = None
+        self.__to = []
         self.sendCode(250, 'I remember nothing.')
 
     def dataLineReceived(self, line):
@@ -177,6 +334,7 @@
                     deferred = message.eomReceived()
                     deferred.addCallback(ndeferred.callback)
                     deferred.addErrback(ndeferred.errback)
+                del self.__messages
                 return
             line = line[1:]
         for message in self.__messages:
@@ -189,6 +347,11 @@
         self.sendCode(550, 'Could not send e-mail')
 
     # overridable methods:
+    def receivedHeader(self, helo, origin, recipents):
+        return "Received: From %s ([%s]) by %s; %s" % (
+            helo[0], helo[1], self.host,
+            time.strftime("%a, %d %b %Y %H:%M:%S +0000", time.gmtime()))
+    
     def validateFrom(self, helo, origin, success, failure):
         success(origin)
 
@@ -281,6 +444,7 @@
             self.state = 'afterData'
 
         chunk = string.replace(chunk, "\n", "\r\n")
+        chunk = string.replace(chunk, "\r\n.", "\r\n..")
         self.transport.write(chunk)
 
     def pauseProducing(self):


More information about the Twisted-Python mailing list