[Twisted-Python] Postgres client

Sune Kirkeby sune at mel.interspace.dk
Sun Jul 27 07:28:28 EDT 2003


Hello Twisted people.

People have expressed interest in the PostgreSQL client I wrote for
Twisted way back, so here it is.

It has the packet parsing/formatting and a somewhat unusual interface,
so as far as I know all it really needs is integration with the rest
of Twisted (twisted.enterprise?) and PostgreSQL-to-Python type mappings
(the PostgreSQL type identifiers are listed in
/usr/include/postgresql/server/catalog/pg_type.h).

The code carries no license notice in the patch, but you may consider
it licensed under the LGPL, and I give Glyph non-exclusive
copyrights to the code. There, that should make it Twisted-friendly,
I hope.

Enjoy,

/s

-- 
Sune Kirkeby | If humans were supposed to fly they'd
             | be born with stewardesses.
-------------- next part --------------
diff --exclude-from=diffignore -Naur Twisted-0.12.3/doc/examples/postgresql.py Twisted/doc/examples/postgresql.py
--- Twisted-0.12.3/doc/examples/postgresql.py	Thu Jan  1 01:00:00 1970
+++ Twisted/doc/examples/postgresql.py	Tue Jan  1 22:27:04 2002
@@ -0,0 +1,56 @@
+"""Example: run eight queries through the PostgreSQL client.
+
+Each query gets its own Consumer, which counts the rows it receives and
+prints progress prefixed with the query's number.
+"""
+
+from twisted.enterprise import spock
+from twisted.protocols import postgresql
+from twisted.internet import main
+from twisted.internet import tcp
+
+class Consumer:
+    """Counting consumer for PostgreSQLClient query results.
+
+    Provides the methods the client calls on a consumer: receivedHead,
+    receivedRow and queryDone.
+    """
+    def __init__(self, prefix):
+        self.done = 0
+        self.count = 0
+        self.types = None
+        self.prefix = prefix    # printed in front of every progress line
+
+    def receivedHead(self, types):
+        # Column descriptions from the RowDescription message.
+        assert self.types is None
+        self.types = types
+
+    def receivedRow(self, data):
+        # The client passes only the field values; the column
+        # descriptions came earlier through receivedHead().
+        assert self.types is not None
+        assert len(self.types) == len(data)
+        assert not self.done
+
+        self.count = self.count + 1
+        if self.count % 100 == 0:
+            print('%s Got 100 rows' % self.prefix)
+
+    def queryDone(self):
+        print('%s Done (%d rows)' % (self.prefix, self.count))
+        self.done = 1
+
+if __name__ == '__main__':
+    pf = postgresql.PostgreSQLClientFactory('sune')
+    # NOTE(review): spock.ConnectionPool is not part of this patch; the
+    # (host, port, factory) constructor and query(sql, consumer) method
+    # used here are assumed -- confirm against twisted.enterprise.
+    pool = spock.ConnectionPool('localhost', 5432, pf)
+
+    for i in range(8):
+        sql = 'SELECT * FROM foo WHERE id=%d' % i
+        consumer = Consumer('%d' % i)
+        pool.query(sql, consumer)
+
+    main.run()
diff --exclude-from=diffignore -Naur Twisted-0.12.3/twisted/protocols/postgresql.py Twisted/twisted/protocols/postgresql.py
--- Twisted-0.12.3/twisted/protocols/postgresql.py	Thu Jan  1 01:00:00 1970
+++ Twisted/twisted/protocols/postgresql.py	Wed Jan  2 13:54:53 2002
@@ -0,0 +1,491 @@
+"""Client for the PostgreSQL frontend/backend protocol, version 2.0.
+
+parsePacket() decodes backend messages into *Packet objects, the frontend
+*Packet classes encode themselves through send(transport), and
+PostgreSQLClient is the Twisted protocol that drives a connection.  Query
+results go to a consumer object and connection events to an observer
+object; SilentConsumer and SilentObserver show the expected methods.
+"""
+
+from socket import htons, htonl, ntohs, ntohl
+import struct
+import string
+import array
+
+from twisted.protocols import protocol
+
+# Authentication request codes carried by the backend's 'R' message.
+AUTH_OK = 0
+AUTH_CLEARTEXT = 3
+AUTH_CRYPT = 4
+AUTH_MD5 = 5
+
+class IncompletePacket(Exception):
+    """Raised by the parse helpers when the data ends mid-message."""
+
+class NotReadyError(Exception):
+    """Raised when a packet is sent while a request is still in flight."""
+
+# All integers on the wire are big-endian.  The '!' struct prefix selects
+# network byte order with standard 2/4-byte sizes and keeps the sign, so
+# negative fields (e.g. a type size or type modifier of -1) decode
+# correctly on every platform.
+
+def parseBytes(l, s):
+    """Split the first l bytes off s; return (head, rest)."""
+    if len(s) < l:
+        raise IncompletePacket()
+    return s[0:l], s[l:]
+
+def formatInt16(i):
+    """Encode i as a signed big-endian Int16."""
+    return struct.pack('!h', i)
+
+def parseInt16(s):
+    """Decode a signed big-endian Int16 from s; return (value, rest)."""
+    l = struct.calcsize('!h')
+    if len(s) < l:
+        raise IncompletePacket()
+    return struct.unpack('!h', s[0:l])[0], s[l:]
+
+def formatInt32(i):
+    """Encode i as a signed big-endian Int32."""
+    return struct.pack('!i', i)
+
+def parseInt32(s):
+    """Decode a signed big-endian Int32 from s; return (value, rest)."""
+    l = struct.calcsize('!i')
+    if len(s) < l:
+        raise IncompletePacket()
+    return struct.unpack('!i', s[0:l])[0], s[l:]
+
+def formatLimString(l, s):
+    """Encode s as a field of exactly l bytes.
+
+    Shorter strings are NUL-padded; longer ones are truncated.
+    """
+    if len(s) < l:
+        return s + '\0' * (l - len(s))
+    else:
+        return s[0:l]
+
+def formatString(s):
+    """Encode s as a NUL-terminated string."""
+    return s + '\0'
+
+def parseString(s):
+    """Decode a NUL-terminated string from s; return (value, rest)."""
+    l = s.find('\0')
+    if l < 0:
+        raise IncompletePacket()
+    return s[0:l], s[l + 1:]
+
+
+# Frontend (client -> backend) messages.
+
+class StartupPacket:
+    """Startup message: protocol version 2.0 plus fixed-width fields."""
+    def __init__(self, user, database='', args='', tty=''):
+        self.user = user
+        self.database = database
+        self.args = args
+        self.tty = tty
+
+    def send(self, transport):
+        s = formatInt16(2)            # protocol version, major
+        s += formatInt16(0)           # protocol version, minor
+        s += formatLimString(64, self.database)
+        s += formatLimString(32, self.user)
+        s += formatLimString(64, self.args)
+        s += formatLimString(64, '')  # unused
+        s += formatLimString(64, self.tty)
+        # The length prefix counts itself (4 bytes): 296 bytes in all.
+        transport.write(formatInt32(len(s) + 4) + s)
+
+class PasswordPacket:
+    """Password message answering a cleartext-password request."""
+    def __init__(self, password):
+        self.password = password
+
+    def send(self, transport):
+        s = formatString(self.password)
+        # The length prefix counts itself (4 bytes) plus the password.
+        transport.write(formatInt32(len(s) + 4) + s)
+
+class TerminatePacket:
+    def send(self, transport):
+        transport.write('X')
+
+class QueryPacket:
+    def __init__(self, query):
+        self.query = query
+
+    def send(self, transport):
+        transport.write('Q' + formatString(self.query))
+
+
+# Backend (backend -> client) messages, as returned by parsePacket().
+
+class CursorResponsePacket:
+    def __init__(self, name):
+        self.name = name
+
+class EmptyQueryResponsePacket:
+    pass
+
+class CompletedResponsePacket:
+    def __init__(self, cmd):
+        self.command = cmd
+
+class AuthenticationPacket:
+    def __init__(self, auth, salt=''):
+        self.authentication = auth  # one of the AUTH_* codes
+        self.salt = salt            # 2 bytes for crypt, 4 for MD5, else ''
+
+class BackendKeyDataPacket:
+    def __init__(self, pid, key):
+        self.process_id = pid
+        self.key = key
+
+class ReadyForQueryPacket:
+    pass
+
+class RowDescriptionPacket:
+    def __init__(self, columns):
+        # tuple of (name, (type_oid, type_size, type_modifier))
+        self.columns = columns
+
+class AsciiRowPacket:
+    def __init__(self, columns):
+        # list of raw field values; None where the null bitmap bit is clear
+        self.columns = columns
+
+class ErrorPacket:
+    def __init__(self, message):
+        self.message = message
+
+class NoticePacket:
+    def __init__(self, message):
+        self.message = message
+
+class UnknownPacket(Exception):
+    pass
+
+
+def parsePacket(client, data):
+    """Parse one backend message from the front of data.
+
+    client is only consulted for 'D' (AsciiRow) messages, whose layout
+    depends on the column count of the preceding RowDescription
+    (client.row_description).
+
+    Returns (packet, rest).  When data holds only part of a message,
+    returns (None, data) unchanged.  Raises UnknownPacket(data) for an
+    unrecognised message tag.
+    """
+    orig_data = data
+
+    try:
+        tag, data = parseBytes(1, data)
+
+        if tag == 'E':
+            error, data = parseString(data)
+            return ErrorPacket(error), data
+
+        if tag == 'N':
+            notice, data = parseString(data)
+            return NoticePacket(notice), data
+
+        if tag == 'R':
+            auth, data = parseInt32(data)
+            # Crypt and MD5 requests are followed by a salt; consume it so
+            # the following message is parsed from the right offset.
+            salt = ''
+            if auth == AUTH_CRYPT:
+                salt, data = parseBytes(2, data)
+            elif auth == AUTH_MD5:
+                salt, data = parseBytes(4, data)
+            return AuthenticationPacket(auth, salt), data
+
+        if tag == 'K':
+            pid, data = parseInt32(data)
+            key, data = parseInt32(data)
+            return BackendKeyDataPacket(pid, key), data
+
+        if tag == 'Z':
+            return ReadyForQueryPacket(), data
+
+        if tag == 'P':
+            name, data = parseString(data)
+            return CursorResponsePacket(name), data
+
+        if tag == 'I':
+            unused, data = parseString(data)
+            return EmptyQueryResponsePacket(), data
+
+        if tag == 'T':
+            count, data = parseInt16(data)
+            columns = []
+            for i in range(count):
+                name, data = parseString(data)
+                type_oid, data = parseInt32(data)
+                type_size, data = parseInt16(data)
+                type_modifier, data = parseInt32(data)
+                columns.append((name, (type_oid, type_size, type_modifier)))
+            return RowDescriptionPacket(tuple(columns)), data
+
+        if tag == 'D':
+            field_count = len(client.row_description)
+            # Null bitmap: one bit per field, most significant bit first,
+            # padded to whole bytes; a set bit means a value follows.
+            nbytes, extra = divmod(field_count, 8)
+            if extra:
+                nbytes = nbytes + 1
+            bitmap, data = parseBytes(nbytes, data)
+            bitmap = array.array('B', bitmap)
+
+            fields = []
+            for i in range(field_count):
+                index, bit = divmod(i, 8)
+                if bitmap[index] & (0x80 >> bit):
+                    # Int32 length (counting itself), then the value bytes.
+                    size, data = parseInt32(data)
+                    value, data = parseBytes(size - 4, data)
+                    fields.append(value)
+                else:
+                    fields.append(None)
+
+            return AsciiRowPacket(fields), data
+
+        if tag == 'C':
+            cmd, data = parseString(data)
+            return CompletedResponsePacket(cmd), data
+
+    except IncompletePacket:
+        return None, orig_data
+
+    raise UnknownPacket(orig_data)
+
+
+class SilentObserver:
+    """Observer that ignores every event; documents the observer interface.
+
+    backendNotice, terminated and connectionLost are optional:
+    PostgreSQLClient skips them for observers that do not define them.
+    """
+    def connectionMade(self, client):
+        pass
+
+    def backendError(self, client, message):
+        pass
+
+    def backendNotice(self, client, message):
+        pass
+
+    def protocolError(self, client, message):
+        pass
+
+    def readyForQuery(self, client):
+        pass
+
+    def terminated(self, client):
+        pass
+
+    def connectionLost(self, client):
+        pass
+
+class SilentConsumer:
+    """Consumer that discards results; documents the consumer interface.
+
+    receivedHead(types) gets the RowDescription columns, receivedRow(data)
+    the field values of each row, and queryDone() is called when the
+    backend reports it is ready for the next query.
+    """
+    def __init__(self):
+        self.done = 0
+        self.types = None
+
+    def receivedHead(self, types):
+        assert self.types is None
+        self.types = types
+
+    def receivedRow(self, data):
+        assert self.types is not None
+        assert len(self.types) == len(data)
+        assert not self.done
+
+    def queryDone(self):
+        self.done = 1
+
+class PostgreSQLClientFactory(protocol.ClientFactory):
+    """Builds PostgreSQLClient instances with fixed connection settings."""
+    def __init__(self, user, password='', args='', tty=''):
+        self.args = (user, password, args, tty)
+
+    def __getstate__(self):
+        return self.args
+
+    def __setstate__(self, state):
+        self.args = state
+
+    def buildProtocol(self, conn):
+        p = PostgreSQLClient(*self.args)
+        p.factory = self
+        return p
+
+class PostgreSQLClient(protocol.Protocol):
+    """Twisted protocol speaking the PostgreSQL 2.0 frontend protocol.
+
+    At most one request is in flight: self.ready is set by connectionMade
+    and by each ReadyForQuery message, and cleared by every sendPacket().
+    """
+    def __init__(self, user, password='', args='', tty=''):
+        self.user = user
+        self.password = password
+        self.backend_args = args
+        self.backend_tty = tty
+
+        self.ready = 0
+        self.terminating = 0
+        self.buffer = ''
+        self.row_description = None
+
+        self.observer = SilentObserver()
+        self.consumer = SilentConsumer()
+
+    def setObserver(self, ob):
+        if ob is None:
+            ob = SilentObserver()
+        self.observer = ob
+
+    def connectionMade(self):
+        self.ready = 1
+        self.sendPacket(StartupPacket(self.user, args=self.backend_args,
+                                      tty=self.backend_tty))
+
+    def connectionLost(self, *args):
+        # *args absorbs the reason argument some Twisted versions pass.
+        self.ready = 0
+        self._notify('connectionLost')
+
+    def query(self, query, consumer=None):
+        """Send query and deliver its results to consumer.
+
+        consumer defaults to a fresh SilentConsumer per call.  Raises
+        NotReadyError while another request is in flight.
+        """
+        if consumer is None:
+            consumer = SilentConsumer()
+        if not self.ready:
+            # Checked before assigning self.consumer so the consumer of
+            # the request in flight is not replaced.
+            raise NotReadyError('Not ready for query.')
+        self.consumer = consumer
+        self.sendPacket(QueryPacket(query))
+
+    def terminate(self):
+        """Send Terminate, close the connection and notify the observer.
+
+        May be called at any time; calls after the first do nothing.
+        """
+        if self.terminating:
+            return
+        self.terminating = 1
+        self.ready = 0
+        self.buffer = ''    # discard any unparsed input
+        # Written directly rather than via sendPacket(): Terminate must be
+        # sent even while a request is in flight.
+        TerminatePacket().send(self.transport)
+        self.transport.loseConnection()
+        self._notify('terminated')
+
+    def sendPacket(self, p):
+        """Send packet p; raise NotReadyError if a request is in flight."""
+        if not self.ready:
+            raise NotReadyError('Not ready for query.')
+
+        self.ready = 0
+        p.send(self.transport)
+
+    def _notify(self, name, *args):
+        # Call an optional observer method; observers written before it
+        # was added simply do not get the event.
+        method = getattr(self.observer, name, None)
+        if method is not None:
+            method(self, *args)
+
+    def dataReceived(self, data):
+        if self.terminating:
+            return      # connection is closing; ignore late data
+        self.buffer = self.buffer + data
+        while self.buffer:
+            try:
+                packet, self.buffer = parsePacket(self, self.buffer)
+            except UnknownPacket:
+                # parsePacket raised before self.buffer was reassigned, so
+                # it still holds the offending bytes.
+                s = 'Unknown packet: %s' % repr(self.buffer)
+                self.observer.protocolError(self, s)
+                self.terminate()
+                return
+
+            if packet is None:
+                break   # only part of the next message has arrived
+
+            self._handlePacket(packet)
+            if self.terminating:
+                return
+
+    def _handlePacket(self, packet):
+        # Dispatch one parsed backend message.
+        if isinstance(packet, AuthenticationPacket):
+            auth = packet.authentication
+            if auth == AUTH_OK:
+                self.observer.connectionMade(self)
+            elif auth == AUTH_CLEARTEXT:
+                # Written directly: during startup the client is not
+                # "ready" in the sendPacket() sense.
+                PasswordPacket(self.password).send(self.transport)
+            else:
+                s = 'Got request for unsupported ' + \
+                    'authentication: %d' % auth
+                self.observer.protocolError(self, s)
+                self.terminate()
+
+        elif isinstance(packet, BackendKeyDataPacket):
+            self.backend_key = packet
+
+        elif isinstance(packet, ReadyForQueryPacket):
+            self.ready = 1
+            self.consumer.queryDone()
+            # queryDone() may already have sent the next query, which
+            # clears self.ready again.
+            if self.ready:
+                self.observer.readyForQuery(self)
+
+        elif isinstance(packet, RowDescriptionPacket):
+            self.row_description = packet.columns
+            self.consumer.receivedHead(self.row_description)
+
+        elif isinstance(packet, ErrorPacket):
+            self.observer.backendError(self, packet.message)
+            # The failed query's consumer receives no further callbacks.
+            self.consumer = SilentConsumer()
+
+        elif isinstance(packet, NoticePacket):
+            self._notify('backendNotice', packet.message)
+
+        elif isinstance(packet, AsciiRowPacket):
+            self.consumer.receivedRow(packet.columns)
+
+        elif (isinstance(packet, CursorResponsePacket) or
+              isinstance(packet, CompletedResponsePacket) or
+              isinstance(packet, EmptyQueryResponsePacket)):
+            pass    # nothing to act on
+
+        else:
+            s = 'Got a "%s" I do not know what to do with!' % \
+                packet.__class__.__name__
+            self.observer.protocolError(self, s)
+
+__all__ = ['PostgreSQLClient']


More information about the Twisted-Python mailing list