[Twisted-Python] Patch to support DNS updates

Jeff Silver jeff at jamcupboard.co.uk
Fri Feb 25 06:26:50 EST 2005


Here is a patch to extend twisted.names to support DNS update messages.
Thanks to those who answered my previous queries.

Attachments:
    update_descr.txt      Description
    authority.py.diff     diff -u of twisted/names/authority.py
    server.py.diff        diff -u of twisted/names/server.py
    tap.py.diff           diff -u of twisted/names/tap.py
    dns.py.diff           diff -u of twisted/protocols/dns.py
-------------- next part --------------
Extension to twisted to implement UPDATE requests
=================================================

Changed files:
    names/authority.py
    names/server.py
    names/tap.py
    protocols/dns.py

Description:

    Updates are accepted for zones for which this server is the
    authority. These updates are applied to the running server.
    If the zones were loaded through the --pyzone or --bindzone
    options, the changes are not written back to disk. If the
    server is run with --nosave, the updates are lost when the
    server shuts down.

    Two new mktap options are added: --pyzonefile and --bindzonefile.
    These behave like the existing --pyzone and --bindzone options,
    except:
        a) At mktap time, the filenames are recorded but the files
           are not read. They are read on server start-up. This
           means that you can edit the files and restart the server
           to implement changes without doing another mktap.
        b) When an update has been applied, the appropriate file
           is renamed by having its serial number appended, and a
           new copy is written. This means that updates persist
           across server restarts, even if --nosave is specified.
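
    Step (b) can be sketched in modern Python as follows. The sketch
    reproduces the serial arithmetic from the patch's
    _trySaveToOriginalFile: a YYYYMMDDHHMM timestamp minus 200000000000,
    bumped to the old serial + 1 if that would not increase it. It also
    reproduces the keep-the-old-copy file handling. next_serial,
    rotate_and_save and the write_zone callback are hypothetical names.
    Unlike the patch, which renames twice, the sketch copies the old file
    aside and then moves the new one into place with os.replace, so the
    zone file never goes missing.

```python
import os
import shutil
import time


def next_serial(old_serial, when=None):
    """Serial scheme from the patch: YYYYMMDDHHMM minus 200000000000,
    or old_serial + 1 if the timestamp would not increase the serial."""
    when = when if when is not None else time.localtime()
    serial = int(time.strftime("%Y%m%d%H%M", when)) - 200000000000
    if serial <= old_serial:
        serial = old_serial + 1
    return serial


def rotate_and_save(filename, old_serial, write_zone):
    """Keep the current file as filename.<old_serial> and install a new copy.

    write_zone(path) is a hypothetical callback that writes the updated zone.
    """
    tmp_filename = filename + ".new"
    save_filename = "%s.%d" % (filename, old_serial)
    write_zone(tmp_filename)                 # write the new zone first
    shutil.copy2(filename, save_filename)    # keep the old version
    os.replace(tmp_filename, filename)       # atomic replacement on POSIX
    return save_filename
```

    For a timestamp of 2005-02-25 09:34, next_serial(1, ...) gives
    502250934. A zone that already follows the common YYYYMMDDnn
    convention (for example 2005022501) has a larger serial than that,
    so its serial simply goes up by one. By the same arithmetic, the
    timestamp-based value no longer fits in the 32-bit serial field from
    2043 onward.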

Limitations and oddities:

    There is no forwarding of update requests.

    There is no processing of prerequisites.
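
    For reference, RFC 2136 sections 2.4 and 3.2 define the prerequisite
    checks that the patch skips. The following is a minimal sketch in
    modern Python. It assumes a hypothetical zone model: a dict mapping
    owner name to a list of (type, rdata) pairs. Prerequisites are given
    as (name, type, class, ttl, rdata) tuples, with rdata=None standing
    for RDLENGTH 0. The sketch leaves out the NOTZONE check for names
    outside the zone.

```python
# Class/type codes and RCODEs from RFC 1035 and RFC 2136.
ANY, NONE = 255, 254
NOERROR, FORMERR, NXDOMAIN, YXDOMAIN, YXRRSET, NXRRSET = 0, 1, 3, 6, 7, 8


def check_prereqs(zone, prereqs, zone_class):
    """Return the RCODE for a list of (name, type, class, ttl, rdata) prerequisites.

    zone maps owner name -> list of (type, rdata); rdata None means RDLENGTH 0.
    """
    wanted = {}  # (name, type) -> set of rdata, for value-dependent checks
    for name, rtype, rclass, ttl, rdata in prereqs:
        if ttl != 0:
            return FORMERR
        present = zone.get(name, [])
        types = {t for t, _ in present}
        if rclass == ANY:
            if rdata is not None:
                return FORMERR
            if rtype == ANY:
                if not present:
                    return NXDOMAIN      # 2.4.4 Name Is In Use
            elif rtype not in types:
                return NXRRSET           # 2.4.1 RRset Exists (Value Independent)
        elif rclass == NONE:
            if rdata is not None:
                return FORMERR
            if rtype == ANY:
                if present:
                    return YXDOMAIN      # 2.4.5 Name Is Not In Use
            elif rtype in types:
                return YXRRSET           # 2.4.3 RRset Does Not Exist
        elif rclass == zone_class:
            wanted.setdefault((name, rtype), set()).add(rdata)
        else:
            return FORMERR
    # 2.4.2 RRset Exists (Value Dependent): the whole RRset must match exactly.
    for (name, rtype), rdatas in wanted.items():
        actual = {d for t, d in zone.get(name, []) if t == rtype}
        if actual != rdatas:
            return NXRRSET
    return NOERROR
```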

    There is no permission checking.

    Only SOA, A, and MX records are written to the new files with
    their data; records of any other type are written without it.
    (Just laziness on my part. I only need A records at present.)
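
    One way to lift this limitation is a table of per-type formatters
    that refuses record types it does not know, instead of writing them
    without their data. The following is a sketch in modern Python that
    produces BIND-format lines. The (owner, ttl, type, rdata) record
    shape and the zone_line helper are hypothetical, and the TXT
    escaping covers only quotes and backslashes. As in the patch's
    BindAuthority.saveFile, the TTL is written only when it differs from
    the zone default.

```python
def _quote(s):
    """Quote one TXT character-string, escaping backslashes and double quotes."""
    return '"%s"' % s.replace("\\", "\\\\").replace('"', '\\"')


# rdata formatters keyed by record type name; rdata is a tuple of fields.
FORMATTERS = {
    "A": lambda rdata: rdata[0],
    "NS": lambda rdata: rdata[0],
    "CNAME": lambda rdata: rdata[0],
    "MX": lambda rdata: "%d %s" % rdata,
    "TXT": lambda rdata: " ".join(_quote(s) for s in rdata),
}


def zone_line(owner, ttl, rtype, rdata, default_ttl):
    """Format one BIND zone file line, omitting the TTL when it is the default."""
    try:
        fmt = FORMATTERS[rtype]
    except KeyError:
        raise ValueError("cannot write %s records" % rtype) from None
    ttl_field = "" if ttl == default_ttl else " %d" % ttl
    return "%s%s IN %s %s" % (owner, ttl_field, rtype, fmt(rdata))
```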

    If an error occurs on writing the file, an error is returned
    to the client, but the update survives in the running server.
    (This contravenes section 3.5 of RFC 2136.)
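
    The following sketch in modern Python shows the behaviour that RFC
    2136 section 3.5 asks for. It snapshots the records, applies the
    change, and restores the snapshot if writing the file fails, so the
    error returned to the client matches the state of the running
    server. Zone, apply_update and the save callback are hypothetical
    stand-ins for FileAuthority and its saveFile. Replying with SERVFAIL
    (RCODE 2, ESERVER in twisted.protocols.dns) is one reasonable choice.

```python
import copy

NOERROR, SERVFAIL = 0, 2


class Zone:
    """Hypothetical stand-in for FileAuthority: records maps owner -> list of RRs."""

    def __init__(self, records, save):
        self.records = records
        self.save = save  # callable that writes a records dict to disk

    def apply_update(self, change):
        """Run change(records); keep the result only if it reaches disk.

        change is assumed not to raise; only failures in save are rolled back.
        """
        snapshot = copy.deepcopy(self.records)
        change(self.records)
        try:
            self.save(self.records)
        except OSError:
            self.records = snapshot  # roll back the in-memory change
            return SERVFAIL
        return NOERROR
```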

    I haven't done anything to ensure atomicity of operations
    or serialization of updates as described in section 3.7 of
    RFC 2136.
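
    As the patch stands, handleUpdate applies each update RR as it goes
    and rewrites the file after each one. A malformed RR late in the
    message therefore leaves the earlier ones applied. RFC 2136 has the
    server prescan the whole update section before changing anything
    (section 3.4.1) and apply the message as a unit (section 3.7). The
    following is a simplified sketch in modern Python. The record model
    (owner -> list of (type, ttl, rdata)) and the update tuples (name,
    type, class, ttl, rdata) are hypothetical. The prescan checks only a
    subset of the RFC's rules, and the special SOA/NS rules are skipped.
    The caller is expected to swap in the returned copy in one step.

```python
import copy

ANY, NONE = 255, 254
NOERROR, FORMERR = 0, 1


def prescan(updates, zone_class):
    """Reject the message before touching the zone (subset of RFC 2136 3.4.1.3)."""
    for name, rtype, rclass, ttl, rdata in updates:
        if rclass == zone_class:
            if rtype == ANY:
                return FORMERR
        elif rclass == ANY:
            if ttl != 0 or rdata is not None:
                return FORMERR
        elif rclass == NONE:
            if ttl != 0 or rtype == ANY:
                return FORMERR
        else:
            return FORMERR
    return NOERROR


def apply_all(records, updates, zone_class):
    """Return (new_records, rcode); records itself is never modified."""
    rcode = prescan(updates, zone_class)
    if rcode != NOERROR:
        return records, rcode
    new = copy.deepcopy(records)
    for name, rtype, rclass, ttl, rdata in updates:
        rrs = new.get(name, [])
        if rclass == zone_class:
            # Add to an RRset; a duplicate (same type and rdata) is replaced.
            rrs = [rr for rr in rrs if (rr[0], rr[2]) != (rtype, rdata)]
            rrs.append((rtype, ttl, rdata))
        elif rtype == ANY:
            rrs = []                                      # class ANY, type ANY: delete name
        elif rclass == ANY:
            rrs = [rr for rr in rrs if rr[0] != rtype]    # delete an RRset
        else:
            # Class NONE: delete one RR, matching type and rdata but not TTL.
            rrs = [rr for rr in rrs if (rr[0], rr[2]) != (rtype, rdata)]
        if rrs:
            new[name] = rrs
        else:
            new.pop(name, None)
    return new, NOERROR
```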

    There seems to be some confusion over the meaning of the
    fifth (last) field in an SOA record. (O'Reilly's DNS and BIND
    book is itself ambiguous.) Since RFC 2308 this field has been the
    negative caching TTL; previously it was used as the default TTL
    for the zone's records. twisted.names handles a $TTL line in
    BIND-format files correctly, but it stores the fifth value in
    'minimum', and the Python source format doesn't seem to have
    any equivalent of the $TTL line. I've tried to stay consistent
    with what's already there, but I'm not sure I've done the right
    thing with this.
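
    Under RFC 2308, the $TTL directive sets the default TTL for records,
    while the SOA's fifth field is the negative caching TTL. For older
    files with no $TTL line, a common fallback (BIND 9 does this, with a
    warning) is to use the SOA's fifth field as the default TTL. The
    following is a simplified sketch in modern Python. ttl_settings is a
    hypothetical helper. It handles only plain-integer TTLs and a single
    SOA record, and it assumes ';' comments never appear inside quoted
    strings.

```python
import re


def ttl_settings(zone_text):
    """Return (default_ttl, negative_ttl) for a simplified BIND zone file."""
    text = re.sub(r";[^\n]*", "", zone_text)               # drop comments
    ttl_line = re.search(r"^\$TTL\s+(\d+)", text, re.MULTILINE)
    tokens = text.replace("(", " ").replace(")", " ").split()
    i = tokens.index("SOA")
    # After "SOA": mname, rname, then serial, refresh, retry, expire, minimum.
    fields = [int(t) for t in tokens[i + 3:i + 8]]
    minimum = fields[4]
    # Legacy files without $TTL: the fifth field doubled as the default TTL.
    default_ttl = int(ttl_line.group(1)) if ttl_line else minimum
    return default_ttl, minimum
```

    This matches what the patch's BindAuthority.saveFile writes: the
    record's ttl goes into the $TTL line and 'minimum' into the fifth
    SOA field.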

    I agree with the comment in the code "this might suck" about
    using the source filename as the zone origin, but I haven't
    made any changes to this. I'm not sure what exactly the zone
    origin is supposed to be.
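
    For reference, the origin of a zone file is the domain name used to
    complete the relative owner names in the file (RFC 1035, section
    5.1). It is normally the zone apex. A file can set it explicitly
    with a $ORIGIN line, or the loader can pass it in, as BIND does with
    the zone name from named.conf. The following sketch in modern Python
    contrasts the patch's rule (file basename plus a trailing dot) with
    respecting an absolute $ORIGIN. It also shows how owner names can be
    relativized when a file is written back. zone_origin and relativize
    are hypothetical helpers, and relative $ORIGIN values are not
    handled.

```python
import os
import re


def zone_origin(filename, zone_text=""):
    """The patch's rule (basename + '.') unless an absolute $ORIGIN is present."""
    m = re.search(r"^\$ORIGIN\s+(\S+\.)(?=\s|$)", zone_text, re.MULTILINE)
    if m:
        return m.group(1).lower()
    return os.path.basename(filename) + "."


def relativize(owner, origin):
    """Write owner relative to origin where possible ('@' for the apex)."""
    absolute = owner if owner.endswith(".") else owner + "."
    if absolute.lower() == origin:
        return "@"
    if absolute.lower().endswith("." + origin):
        return absolute[: -len(origin) - 1]
    return absolute
```

    A file named db.example.com shows the problem with the filename
    rule: its origin comes out as db.example.com. rather than
    example.com.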
-------------- next part --------------
--- names/authority.py	2004-02-07 17:57:11.000000000 +0000
+++ names/authority.py.new	2005-02-25 09:34:02.725600288 +0000
@@ -161,6 +161,78 @@
                 add.extend(res[1][2])
         return ans, auth, add
 
+    def _trySaveToOriginalFile(self):
+        '''Called after an update.'''
+        soa_rec = self.soa[1]
+        new_serial = int(time.strftime('%Y%m%d%H%M')) - 200000000000
+        if new_serial <= soa_rec.serial:
+            new_serial = soa_rec.serial + 1
+        old_serial = soa_rec.serial
+        soa_rec.serial = new_serial
+        if hasattr(self, 'filename'):
+            tmp_filename = self.filename + '.new'
+            save_filename = '%s.%d' % (self.filename, old_serial)
+            self.saveFile(tmp_filename)
+            os.rename(self.filename, save_filename)
+            os.rename(tmp_filename, self.filename)
+
+    def addRR(self, update):
+        key = str(update.name)
+        try:
+            rrlist = self.records[key]
+        except KeyError:
+            # create new entry
+            rrlist = self.records[key] = []
+        record = update.payload
+        for rec in rrlist:
+            if record == rec:
+                # duplicate. ignore
+                return
+        rrlist.append(record)
+        self._trySaveToOriginalFile()
+
+    def deleteName(self, record):
+        try:
+            del self.records[str(record.name)]
+        except KeyError:
+            return
+        self._trySaveToOriginalFile()
+
+    def deleteRRset(self, record):
+        try:
+            rrlist = self.records[str(record.name)]
+        except KeyError:
+            return
+        index = 0
+        did_it = False
+        while index < len(rrlist):
+            if record.type == rrlist[index].type:
+                del rrlist[index]
+                did_it = True
+            else:
+                index += 1
+        if did_it:
+            self._trySaveToOriginalFile()
+
+    def deleteRR(self, record):
+        try:
+            rrlist = self.records[str(record.name)]
+        except KeyError:
+            return
+        rec_to_delete = record.payload
+        index = 0
+        did_it = False
+        while index < len(rrlist):
+            if rec_to_delete.__class__ == rrlist[index].__class__ \
+                    and rec_to_delete.preference == rrlist[index].preference \
+                    and rec_to_delete.exchange == rrlist[index].exchange:
+                del rrlist[index]
+                did_it = True
+            else:
+                index += 1
+        if did_it:
+            self._trySaveToOriginalFile()
+
 
 class PySourceAuthority(FileAuthority):
     """A FileAuthority that is built up from Python source code."""
@@ -177,6 +249,26 @@
                 self.soa = rr
             self.records.setdefault(rr[0].lower(), []).append(rr[1])
 
+    def saveFile(self, filename):
+        of = file(filename, 'w')
+        names = self.records.keys()
+        names.sort()
+        print >>of, 'zone = ['
+        for name in names:
+            for rec in self.records[name]:
+                print >>of, "    %s('%s'," % (dns.QUERY_TYPES[rec.TYPE], name),
+                if rec.TYPE == dns.SOA:
+                    print >>of, "mname='%s', rname='%s', serial=%d, refresh=%d, retry=%d, expire=%d, minimum=%d)," % \
+                        (rec.mname, rec.rname, rec.serial, rec.refresh, rec.retry, rec.expire, rec.minimum)
+                elif rec.TYPE == dns.A:
+                    print >>of, "'%s')," % rec.dottedQuad()
+                elif rec.TYPE == dns.MX:
+                    print >>of, "%d, '%s')," % (rec.preference, rec.exchange)
+                else:
+                    print >>of, '),'
+        print >>of, ']'
+        of.close()
+
 
     def wrapRecord(self, type):
         return lambda name, *arg, **kw: (name, type(*arg, **kw))
@@ -194,7 +286,7 @@
 
 class BindAuthority(FileAuthority):
     """An Authority that loads BIND configuration files"""
-    
+
     def loadFile(self, filename):
         self.origin = os.path.basename(filename) + '.' # XXX - this might suck
         lines = open(filename).readlines()
@@ -202,6 +294,37 @@
         lines = self.collapseContinuations(lines)
         self.parseLines(lines)
 
+    def saveFile(self, filename):
+        of = file(filename, 'w')
+        soa_rec = self.soa[1]
+        print >>of, '$TTL %d' % soa_rec.ttl
+        print >>of, '%s. IN SOA %s %s (' % (self.soa[0], soa_rec.mname, soa_rec.rname)
+        for val in (soa_rec.serial, soa_rec.refresh, soa_rec.retry, soa_rec.expire, soa_rec.minimum):
+            print >>of, '\t%d' % val
+        print >>of, ')'
+        dotted_orig = '.' + self.origin
+        names = self.records.keys()
+        names.sort()
+        for name in names:
+            reclist = self.records[name]
+            if name.endswith(dotted_orig):
+                name = name[:-len(dotted_orig)]
+            else:
+                name = name + '.'
+            for rec in reclist:
+                if rec.TYPE == dns.SOA:
+                    continue
+                print >>of, '%s' % name,
+                if rec.ttl != soa_rec.ttl:
+                    print >>of, '%d' % (rec.ttl),
+                print >>of, 'IN %s' % dns.QUERY_TYPES[rec.TYPE],
+                if rec.TYPE == dns.A:
+                    print >>of, '%s' % rec.dottedQuad()
+                elif rec.TYPE == dns.MX:
+                    print >>of, '%d %s' % (rec.preference, rec.exchange)
+                else:
+                    print >>of, ''
+        of.close()
 
     def stripComments(self, lines):
         return [
@@ -279,9 +402,6 @@
             raise NotImplementedError, "Record type %r not supported" % type
 
 
-    #
-    # This file ends here.  Read no further.
-    #
     def parseRecordLine(self, origin, ttl, line):
         MARKERS = dns.QUERY_CLASSES.values() + dns.QUERY_TYPES.values()
         cls = 'IN'
-------------- next part --------------
--- names/server.py	2004-03-01 22:54:20.000000000 +0000
+++ names/server.py.new	2005-02-25 09:34:01.529782080 +0000
@@ -40,6 +40,8 @@
 from twisted.internet import protocol, defer
 from twisted.protocols import dns
 from twisted.python import failure, log
+from twisted.names import authority
 
 import resolve, common
 
@@ -47,21 +49,38 @@
     protocol = dns.DNSProtocol
     cache = None
 
-    def __init__(self, authorities = None, caches = None, clients = None, verbose = 0):
-        resolvers = []
+    def __init__(self, authorities = None, bindfilenames=None, pyfilenames=None, caches = None, clients = None, verbose = 0):
+        self.resolvers = []
         if authorities is not None:
-            resolvers.extend(authorities)
+            self.resolvers.extend(authorities)
         if caches is not None:
-            resolvers.extend(caches)
+            self.resolvers.extend(caches)
         if clients is not None:
-            resolvers.extend(clients)
+            self.resolvers.extend(clients)
+
+        # save authority list for use during updates
+        self.authorities = list(authorities or [])
+
+        self.bindfilenames = bindfilenames or []
+        self.pyfilenames = pyfilenames or []
 
         self.canRecurse = not not clients
-        self.resolver = resolve.ResolverChain(resolvers)
         self.verbose = verbose
         if caches:
             self.cache = caches[-1]
 
+    def startFactory(self):
+        for f in self.bindfilenames:
+            new_auth = authority.BindAuthority(f)
+            new_auth.filename = f   # save filename for rewriting file on update
+            self.authorities.append(new_auth)
+            self.resolvers.append(new_auth)
+        for f in self.pyfilenames:
+            new_auth = authority.PySourceAuthority(f)
+            new_auth.filename = f   # save filename for rewriting file on update
+            self.authorities.append(new_auth)
+            self.resolvers.append(new_auth)
+        self.resolver = resolve.ResolverChain(self.resolvers)
 
     def buildProtocol(self, addr):
         p = self.protocol(self)
@@ -159,6 +178,68 @@
             log.msg("Notify message from %r" % (address,))
 
 
+    def handleUpdate(self, message, protocol, address):
+        if self.verbose:
+            log.msg("Update message from %r" % (address,))
+        query = dns.Query(message.zones[0].name.name, dns.SOA)
+        return self.authQuery(query, self.authorities[:], protocol, message, address)
+
+    def authQuery(self, query, auth_list, protocol, message, address):
+        return auth_list[0].query(query).addCallback(
+            self.gotResolverResponseAuth, query, auth_list, protocol, message, address
+        ).addErrback(
+            self.gotResolverErrorAuth, query, auth_list, protocol, message, address
+        )
+
+    def gotResolverResponseAuth(self, (ans, auth, add), query, auth_list, protocol, message, address):
+        if not ans[0].auth:
+            if self.verbose:
+                log.msg("we are not the authority for this domain")
+            message.rCode = dns.EREFUSED
+        else:
+            auth = auth_list[0]
+            message.rCode = dns.OK
+            did_update = False
+            zone_class = message.zones[0].cls
+            for update in message.updates:
+                if update.cls == zone_class:
+                    if self.verbose:
+                        log.msg("add %s" % (update,))
+                    auth.addRR(update)
+                    did_update = True
+                elif update.ttl == 0 and update.cls == dns.ANY:
+                    if update.type == dns.ANY:
+                        if self.verbose:
+                            log.msg("delete name %s" % (update,))
+                        auth.deleteName(update)
+                        did_update = True
+                    else:
+                        if self.verbose:
+                            log.msg("delete RRset %s" % (update,))
+                        auth.deleteRRset(update)
+                        did_update = True
+                elif update.ttl == 0 and update.cls == dns.NONE:
+                    if self.verbose:
+                        log.msg("delete RR %s" % (update,))
+                    auth.deleteRR(update)
+                    did_update = True
+                else:
+                    if self.verbose:
+                        log.msg("bad combination of values %s" % (update,))
+                    message.rCode = dns.EFORMAT
+        self.sendReply(protocol, message, address)
+
+    def gotResolverErrorAuth(self, failure, query, auth_list, protocol, message, address):
+        del auth_list[0]
+        if auth_list:
+            # keep going
+            return self.authQuery(query, auth_list, protocol, message, address)
+        if self.verbose:
+            log.msg("domain not found")
+        message.rCode = dns.ENAME
+        self.sendReply(protocol, message, address)
+
+
     def handleOther(self, message, protocol, address):
         message.rCode = dns.ENOTIMP
         self.sendReply(protocol, message, address)
@@ -194,6 +275,8 @@
             self.handleStatus(message, proto, address)
         elif message.opCode == dns.OP_NOTIFY:
             self.handleNotify(message, proto, address)
+        elif message.opCode == dns.OP_UPDATE:
+            self.handleUpdate(message, proto, address)
         else:
             self.handleOther(message, proto, address)
 
-------------- next part --------------
--- names/tap.py	2003-12-05 04:54:03.000000000 +0000
+++ names/tap.py.new	2005-02-25 09:34:02.726600136 +0000
@@ -50,7 +50,9 @@
         usage.Options.__init__(self)
         self['verbose'] = 0
         self.bindfiles = []
+        self.bindfilenames = []
         self.zonefiles = []
+        self.pyfilenames = []
         self.secondaries = []
 
 
@@ -66,6 +68,20 @@
             raise usage.UsageError(filename + ": No such file")
         self.bindfiles.append(filename)
 
+    def opt_bindzonefile(self, filename):
+        """Specify the filename of a BIND9 syntax zone definition to be
+        read at start-up time (not at mktap time)."""
+        if not os.path.exists(filename):
+            raise usage.UsageError(filename + ": No such file")
+        self.bindfilenames.append(filename)
+
+    def opt_pyzonefile(self, filename):
+        """Specify the filename of a Python source zone definition to be
+        read at start-up time (not at mktap time)."""
+        if not os.path.exists(filename):
+            raise usage.UsageError(filename + ": No such file")
+        self.pyfilenames.append(filename)
+
 
     def opt_secondary(self, ip_domain):
         """Act as secondary for the specified domain, performing
@@ -119,7 +135,7 @@
     if config['recursive']:
         cl.append(client.createResolver(resolvconf=config['resolv-conf']))
 
-    f = server.DNSServerFactory(config.zones, ca, cl, config['verbose'])
+    f = server.DNSServerFactory(config.zones, config.bindfilenames, config.pyfilenames, ca, cl, config['verbose'])
     p = dns.DNSDatagramProtocol(f)
     f.noisy = 0
     ret = service.MultiService()
-------------- next part --------------
--- protocols/dns.py	2004-02-25 20:16:12.000000000 +0000
+++ protocols/dns.py.new	2005-02-25 09:34:01.528782232 +0000
@@ -88,7 +88,9 @@
 
     33: 'SRV',
     
-    38: 'A6', 39: 'DNAME'
+    38: 'A6', 39: 'DNAME',
+
+    255: 'ANY',  # for updates. Is this the right place, or should parsing be different for queries/updates?
 }
 
 # "Extended" queries (Hey, half of these are deprecated, good job)
@@ -105,7 +107,7 @@
 
 
 QUERY_CLASSES = {
-    1: 'IN',  2: 'CS',  3: 'CH',  4: 'HS',  255: 'ANY'
+    1: 'IN',  2: 'CS',  3: 'CH',  4: 'HS',  254: 'NONE',  255: 'ANY'
 }
 REV_CLASSES = dict([
     (v, k) for (k, v) in QUERY_CLASSES.items()
@@ -116,10 +118,11 @@
 
 
 # Opcodes
-OP_QUERY, OP_INVERSE, OP_STATUS, OP_NOTIFY = range(4)
+OP_QUERY, OP_INVERSE, OP_STATUS, OP_UNKNOWN, OP_NOTIFY, OP_UPDATE = range(6)  # NOTIFY is 4 (RFC 1996), UPDATE is 5 (RFC 2136)
 
 # Response Codes
-OK, EFORMAT, ESERVER, ENAME, ENOTIMP, EREFUSED = range(6)
+OK, EFORMAT, ESERVER, ENAME, ENOTIMP, EREFUSED, YXDOMAIN, YXRRSET, NXRRSET = range(9)
+
 
 class IRecord(components.Interface):
     """An single entry in a zone of authority.
@@ -454,7 +457,8 @@
 
     def decode(self, strio, length = None):
         self.name = Name()
-        self.name.decode(strio)
+        if length:
+            self.name.decode(strio)
     
 
     def __hash__(self):
@@ -462,6 +466,9 @@
 
 
 # Kinds of RRs - oh my!
+class Record_ANY(SimpleRecord):  # for updates. Is this the right place, or should parsing be different for queries/updates?
+    TYPE = ''
+
 class Record_NS(SimpleRecord):
     TYPE = NS
 
@@ -1004,6 +1011,7 @@
         self.recAv = ( byte4 >> 7 ) & 1
         self.rCode = byte4 & 0xf
 
+        # query
         self.queries = []
         for i in range(nqueries):
             q = Query()
@@ -1017,6 +1025,12 @@
         for (l, n) in items:
             self.parseRecords(l, n, strio)
 
+        if self.opCode == OP_UPDATE:
+            # rename fields for readability
+            self.zones = self.queries
+            self.prereqs = self.answers
+            self.updates = self.authority
+
 
     def parseRecords(self, list, num, strio):
         for i in range(num):