Ticket #4610: twisted-dnssec.patch

File twisted-dnssec.patch, 13.3 KB (added by philmayers, 9 years ago)

Basic DNSKEY and RRSIG support, including ability to validate a set of RRs against a DNSKEY

  • twisted/names/dns.py

    diff --git a/twisted/names/dns.py b/twisted/names/dns.py
    index 25f78de..2793281 100644
    a b __all__ = [ 
    4747import warnings
    4848
    4949import struct, random, types, socket
     50import time
     51try:
     52    import hashlib
     53except:
     54    hashlib = None
     55
     56try:
     57    import Crypto.Util
     58    import Crypto.PublicKey.RSA
     59    _has_crypto = True
     60except:
     61    _has_crypto = False
    5062
    5163try:
    5264    import cStringIO as StringIO
    SRV = 33 
    8294NAPTR = 35
    8395A6 = 38
    8496DNAME = 39
     97RRSIG = 46
     98NSEC = 47
     99DNSKEY = 48
    85100
    86101QUERY_TYPES = {
    87102    A: 'A',
    QUERY_TYPES = { 
    109124    SRV: 'SRV',
    110125    NAPTR: 'NAPTR',
    111126    A6: 'A6',
    112     DNAME: 'DNAME'
     127    DNAME: 'DNAME',
     128
     129    RRSIG: 'RRSIG',
     130    NSEC: 'NSEC',
     131    DNSKEY: 'DNSKEY',
    113132}
    114133
    115134IXFR, AXFR, MAILB, MAILA, ALL_RECORDS = range(251, 256)
    class SimpleRecord(tputil.FancyStrMixin, tputil.FancyEqMixin): 
    545564    def encode(self, strio, compDict = None):
    546565        self.name.encode(strio, compDict)
    547566
     567    def canonical(self, strio):
     568        n = Name(self.name.lower())
     569        n.encode(strio)
    548570
    549571    def decode(self, strio, length = None):
    550572        self.name = Name()
    551573        self.name.decode(strio)
    552574
    553 
    554575    def __hash__(self):
    555576        return hash(self.name)
    556577
    class Record_MX(tputil.FancyStrMixin, tputil.FancyEqMixin): 
    14121433        return hash((self.preference, self.name))
    14131434
    14141435
     1436class Record_RRSIG(tputil.FancyEqMixin, tputil.FancyStrMixin):
     1437    implements(IEncodable, IRecord)
     1438
     1439    TYPE = RRSIG
     1440
     1441    fancybasename = 'RRSIG'
     1442    showAttributes = compareAttributes = ('type', 'algorithm', 'labels', 'original_ttl', 'expiration', 'inception', 'keytag', 'signame', 'signature')
     1443
     1444    def __init__(self, signame='', signature='', type=A, algorithm=5, labels=0, original_ttl=0, expiration=None, inception=None, keytag=0):
     1445        self.signame = Name(signame)
     1446        self.signature = signature
     1447        self.type = type
     1448        self.algorithm = algorithm
     1449        self.labels = labels
     1450        self.original_ttl = original_ttl
     1451        if inception is None:
     1452            inception = int(time.time())
     1453        self.inception = inception
     1454        if expiration is None:
     1455            expiration = self.inception + 3600
     1456        self.expiration = expiration
     1457        self.keytag = keytag
     1458
     1459    _fmt = '!HBBIIIH'
     1460    _fmt_size = struct.calcsize(_fmt)
     1461
     1462    def encode(self, strio, compDict=None):
     1463        strio.write(struct.pack(self._fmt, self.type, self.algorithm, self.labels, self.original_ttl, self.expiration, self.inception, self.keytag))
     1464        self.signame.encode(strio, compDict)
     1465        strio.write(self.signature)
     1466
     1467    def decode(self, strio, length):
     1468        if length < self._fmt_size + 1:
     1469            raise Exception('payload too short')
     1470        hdr = readPrecisely(strio, self._fmt_size)
     1471        self.type, self.algorithm, self.labels, self.original_ttl, self.expiration, self.inception, self.keytag = struct.unpack(self._fmt, hdr)
     1472
     1473        length -= self._fmt_size
     1474
     1475        start = strio.tell()
     1476        self.signame = Name()
     1477        self.signame.decode(strio, length)
     1478        end = strio.tell()
     1479
     1480        length -= end - start
     1481
     1482        self.signature = readPrecisely(strio, length)
     1483
     1484    def validate(self, key, original):
     1485        # build the payload to validate
     1486        strio = StringIO.StringIO()
     1487        strio.write(struct.pack(self._fmt, self.type, self.algorithm, self.labels, self.original_ttl, self.expiration, self.inception, self.keytag))
     1488        self.signame.encode(strio)
     1489
     1490        canon = []
     1491
     1492        for rr in original:
     1493            if not rr.type==self.type:
     1494                continue
     1495            payload = rr.payload
     1496
     1497            # FIXME: assert rr.type=self.type?
     1498            # FIXME: assert rr.name==self.rrheader.name?
     1499
     1500            # RFC4034 section 6.2
     1501            io = StringIO.StringIO()
     1502
     1503            if hasattr(payload, 'canonical'):
     1504                payload.canonical(io)
     1505            else:
     1506                payload.encode(io)
     1507
     1508            val = io.getvalue()
     1509            canon.append((val, rr.name, rr.cls))
     1510
     1511        for payload,name,cls in sorted(canon):
     1512            name.encode(strio)
     1513            strio.write(struct.pack('!HHIH', self.type, cls, self.original_ttl, len(payload)))
     1514            strio.write(payload)
     1515           
     1516        sigvalue = strio.getvalue()
     1517        return key.verify(sigvalue, self.signature)
     1518
     1519class Record_DNSKEY(tputil.FancyEqMixin, tputil.FancyStrMixin):
     1520    implements(IEncodable, IRecord)
     1521
     1522    TYPE = DNSKEY
     1523
     1524    fancybasename = 'DNSKEY'
     1525    showAttributes = compareAttributes = ('flags', 'protocol', 'algorithm', 'key')
     1526
     1527    def __init__(self, flags=0, protocol=3, algorithm=5, key=''):
     1528        self.flags = flags
     1529        self.protocol = protocol
     1530        self.algorithm = algorithm
     1531        self.key = key
     1532
     1533    def tag(self):
     1534        data = struct.pack('!HBB', self.flags, self.protocol, self.algorithm) + self.key
     1535        v = 0
     1536        for i in range(len(data)):
     1537            if i & 1:
     1538                v += ord(data[i])
     1539            else:
     1540                v += ord(data[i]) << 8
     1541        v += (v >> 16) & 0xffff
     1542        return v & 0xffff
     1543
     1544    _fmt = '!HBB'
     1545    _fmt_size = struct.calcsize(_fmt)
     1546
     1547    def decode(self, strio, length):
     1548        if length < self._fmt_size:
     1549            raise Exception('too short')
     1550        hdr = readPrecisely(strio, self._fmt_size)
     1551        self.flags, self.protocol, self.algorithm = struct.unpack(self._fmt, hdr)
     1552        length -= self._fmt_size
     1553
     1554        self.key = readPrecisely(strio, length)
     1555
     1556    def encode(self, strio, compDict=None):
     1557        strio.write(struct.pack(self._fmt, self.flags, self.protocol, self.algorithm))
     1558        strio.write(self.key)
     1559
     1560    def verify(self, body, sig):
     1561        if self.algorithm==5:
     1562            return self.verify_rsa_sha1(body, sig)
     1563        raise Exception('unhandled algorithm')
     1564
     1565    def verify_rsa_sha1(self, data, sigvalue):
     1566        # key is either:
     1567        # 1-255=N (exponent len)
     1568        # N bytes exponent
     1569        # rest modulus
     1570        #
     1571        # or
     1572        #
     1573        # \x00
     1574        # 2 bytes exponent len
     1575        # N bytes exponent
     1576        # rest modulus
     1577        key = self.key
     1578        if key[0]=='\x00':
     1579            _ignore, explen = struct.unpack('!BH', key[3:])
     1580            key = key[3:]
     1581        else:
     1582            explen = ord(key[0])
     1583            key = key[1:]
     1584
     1585        exponent = Crypto.Util.number.bytes_to_long(key[:explen])
     1586        modbytes = key[explen:]
     1587        modlen = len(modbytes)
     1588        modulus = Crypto.Util.number.bytes_to_long(modbytes)
     1589
     1590        # RSA-SHA1
     1591        # RFC3110 section 3
     1592        hash = hashlib.new('sha1', data).digest()
     1593        prefix = '\x30\x21\x30\x09\x06\x05\x2B\x0E\x03\x02\x1A\x05\x00\x04\x14'
     1594        padlen = modlen - len(hash) - len(prefix) - 3
     1595        hash = '\x01' + '\xff'*padlen + '\x00' + prefix + hash
     1596
     1597        #print "verifying"
     1598        #print self.tag()
     1599        #print exponent
     1600        #for k in ('modbytes', 'data', 'hash', 'sigvalue'):
     1601        #    v = locals()[k]
     1602        #    print "%10s %4d %r" % (k, len(v), v)
     1603
     1604        # signature -> number
     1605        sig = Crypto.Util.number.bytes_to_long(sigvalue)
     1606
     1607        # verify
     1608        key = Crypto.PublicKey.RSA.construct((modulus, exponent))
     1609
     1610        return key.verify(hash, (sig,''))
     1611       
    14151612
    14161613# Oh god, Record_TXT how I hate thee.
    14171614class Record_TXT(tputil.FancyEqMixin, tputil.FancyStrMixin):
  • twisted/names/test/test_dns.py

    diff --git a/twisted/names/test/test_dns.py b/twisted/names/test/test_dns.py
    index e8a059f..965e77b 100644
    a b except ImportError: 
    1212    from StringIO import StringIO
    1313
    1414import struct
     15import re
    1516
    1617from twisted.python.failure import Failure
    1718from twisted.internet import address, task
    class EqualityTests(unittest.TestCase): 
    12141215            dns.Record_TXT(['foo', 'bar'], 10),
    12151216            dns.Record_TXT(['foo', 'bar'], 10),
    12161217            dns.Record_TXT(['foo', 'bar'], 100))
     1218
     1219class DnssecTests(unittest.TestCase):
     1220    """
     1221    Tests for the DNSSEC records & functions - RRSIG, NSEC, DNSKEY, DS
     1222    """
     1223
     1224    def _mk_rrsig(self, signame, signature, type=dns.A, algorithm=5, labels=1, original_ttl=2, expiration=0, inception=1, keytag=2):
     1225        rrsig = struct.pack('!HBBIIIH', type, algorithm, labels, original_ttl, expiration, inception, keytag)
     1226        for label in signame.split('.'):
     1227            rrsig += struct.pack('!B', len(label))
     1228            rrsig += label
     1229        rrsig += '\x00'
     1230        rrsig += signature
     1231
     1232        return rrsig
     1233
     1234    def test_rrsig_decode(self):
     1235        fields = {
     1236                'signame': 'foo.bar',
     1237                'signature': 'thesignature',
     1238                'type': dns.A,
     1239                'algorithm': 5,
     1240                'labels': 3,
     1241                'original_ttl': 300,
     1242                'expiration': 100000,
     1243                'inception': 200000,
     1244                'keytag': 5678,
     1245                }
     1246
     1247        rrsig = self._mk_rrsig(**fields)
     1248
     1249        rr = dns.Record_RRSIG()
     1250        rr.decode(StringIO(rrsig), len(rrsig))
     1251
     1252        self.assertEqual(
     1253                rr,
     1254                dns.Record_RRSIG(**fields)
     1255                )
     1256
     1257    def test_rrsig_encode(self):
     1258        rr = dns.Record_RRSIG(
     1259                signame='foo.bar',
     1260                signature='thesig',
     1261                type=dns.A,
     1262                algorithm=5,
     1263                labels=3,
     1264                original_ttl=0xbbbb,
     1265                expiration=0xaaaaaaaa,
     1266                inception=0xcccccccc,
     1267                keytag=0xfffe,
     1268                )
     1269        strio = StringIO()
     1270        rr.encode(strio)
     1271        val = strio.getvalue()
     1272
     1273        expect = str(
     1274                '\x00\x01'  # A
     1275                '\x05'      # algo=5
     1276                '\x03'      # labels=3
     1277                '\x00\x00\xbb\xbb'  # ttl
     1278                '\xaa\xaa\xaa\xaa'  # expiration
     1279                '\xcc\xcc\xcc\xcc'  # inception
     1280                '\xff\xfe'          # keytag
     1281                '\x03foo\x03bar\x00'    # foo.bar
     1282                'thesig'
     1283                )
     1284
     1285
     1286        self.assertEqual(
     1287                val,
     1288                expect,
     1289                )
     1290
     1291    def _mk_dnskey(self, flags=0, protocol=0, algorithm=0, key=''):
     1292        dnskey = struct.pack('!HBB', flags, protocol, algorithm)
     1293        dnskey += key
     1294        return dnskey
     1295
     1296    def test_dnskey_decode(self):
     1297        key = str(
     1298                'AwEAAcdYhgqRE+Z5NkzrKGl3fE6aTAtzMJfxWo8fK02j'
     1299                'niePZIEOmG75pGZAjUHh29iyfYHU394VewgNXQYjhryi'
     1300                'j4pdZ7U9DN/kpu6RNvcwPn6F+y/Hz5qsNTFZ/GIjU83J'
     1301                'RrVsU8fTpCY27pik6S5JRJ5l1nHVwptaTlSiLEL+FgQj'
     1302                )   # keytag==51561 - depends exactly on "fields" below!
     1303        key = key.decode('base64')
     1304
     1305        fields = {
     1306                'flags': 256,
     1307                'protocol': 3,
     1308                'algorithm': 5,
     1309                'key': key,
     1310                }
     1311
     1312        dnskey = self._mk_dnskey(**fields)
     1313
     1314        rr = dns.Record_DNSKEY()
     1315        rr.decode(StringIO(dnskey), len(dnskey))
     1316
     1317        self.assertEqual(
     1318                rr,
     1319                dns.Record_DNSKEY(**fields),
     1320                )
     1321        self.assertEqual(
     1322                rr.tag(),
     1323                51561,
     1324                )
     1325
     1326    def test_dnskey_encode(self):
     1327        rr = dns.Record_DNSKEY(flags=257, protocol=3, algorithm=5, key='thekey')
     1328        strio = StringIO()
     1329        rr.encode(strio)
     1330        val = strio.getvalue()
     1331
     1332        self.assertEqual(
     1333                val,
     1334                '\x01\x01\x03\x05thekey',
     1335                )
     1336
     1337    def _keytime2sec(self):
     1338        # convert 20100702204922 into seconds since 1970
     1339        m = re.match('(\d{4})(\d{2})(\d{2})(\d{2})(\d{2})(\d{2})')
     1340        y,m,d,h,m,s = [int(g) for g in m.groups()]
     1341        retu
     1342
     1343
     1344    def test_rrsig_validate(self):
     1345        # check that Record_RRSIG can validate signatures given the original RRs and the key
     1346        # FIXME: should really generate ourselves some test data, not using (old) public data
     1347
     1348        original = [
     1349                dns.RRHeader('www.ic.ac.uk', type=dns.A, cls=dns.IN, ttl=3600, payload=dns.Record_A('155.198.140.14'))
     1350                ]
     1351
     1352        key = str(
     1353                'AwEAAcdYhgqRE+Z5NkzrKGl3fE6aTAtzMJfxWo8fK02j'
     1354                'niePZIEOmG75pGZAjUHh29iyfYHU394VewgNXQYjhryi'
     1355                'j4pdZ7U9DN/kpu6RNvcwPn6F+y/Hz5qsNTFZ/GIjU83J'
     1356                'RrVsU8fTpCY27pik6S5JRJ5l1nHVwptaTlSiLEL+FgQj'
     1357                )   # keytag==51561
     1358        key = key.decode('base64')
     1359        keyrr = dns.Record_DNSKEY(flags=256, protocol=3, algorithm=5, key=key)
     1360
     1361        sigval = str(
     1362                'h574y3uK6FAWZcN5YdAiuZ8E4VOoZf0np7Fkd6kxzoj0'
     1363                'vLROww2MBERn66OyOZ+nWEojr3YyuVk04E0MUKe915Py'
     1364                'GY9dC49RoX/vwM5l25ScgtUJo7K4CgE9X8/7pIXMZ2Xn'
     1365                '/CNAPkqKSKywzLgZkENwOSVn3WSZdW6weqJ5e+k='
     1366                )
     1367        sigval = sigval.decode('base64')
     1368
     1369        sig = dns.Record_RRSIG(type=dns.A, algorithm=5, labels=4, original_ttl=3600,
     1370                expiration=1280695879,  # Aug  1, 2010 21:51:19
     1371                inception=1278103762,   # Jul  2, 2010 21:49:22
     1372                keytag=51561,
     1373                signature=sigval,
     1374                signame='ic.ac.uk',
     1375                )
     1376
     1377        isvalid = sig.validate(keyrr, original)
     1378        self.assertEqual(isvalid, True)
     1379
     1380
     1381