Ticket #4397: pgadbapi.2.py

File pgadbapi.2.py, 18.1 KB (added by Jan Urbański, 12 years ago)

adapt to the changed psycopg2 interface

Line 
1# -*- test-case-name: twisted.test.test_pgadbapi -*-
2# Copyright (c) 2010 Twisted Matrix Laboratories.
3# See LICENSE for details.
4
5"""
6A Twisted wrapper for the asynchronous features of the PostgreSQL psycopg2
7driver.
8"""
9
10import psycopg2
11from psycopg2 import extensions
12from zope.interface import implements
13
14from twisted.internet import interfaces, reactor, defer
15from twisted.python import log
16
17
class UnexpectedPollResult(Exception):
    """
    Used to fail the polling Deferred when a psycopg2 pollable reports a state
    that is none of POLL_OK, POLL_READ or POLL_WRITE.
    """
    pass
22
23
class _PollingMixin(object):
    """
    An object that wraps something pollable. It can take care of waiting for
    the wrapped pollable to reach the OK state and adapts the pollable's
    interface to U{interfaces.IReadWriteDescriptor}. It will forward all
    attribute access that it has not wrapped itself to the underlying
    pollable. Useful as a mixin for classes that wrap a psycopg2 pollable
    object.

    @type reactor: A U{interfaces.IReactorFDSet} provider.
    @ivar reactor: The reactor that the class will use to wait for the wrapped
        pollable to reach the OK state.

    @type prefix: C{str}
    @ivar prefix: Prefix used during log formatting to indicate context.
    """

    implements(interfaces.IReadWriteDescriptor)

    reactor = None
    prefix = "pollable"
    # Deferred handed out by poll(); reset to None right before it fires so
    # that the next poll() starts a fresh wait.
    _pollingD = None

    def pollable(self):
        """
        Return the pollable object. Subclasses should override this.

        @return: A psycopg2 pollable.
        """
        raise NotImplementedError()

    def poll(self):
        """
        Start polling the wrapped pollable.

        If a previous wait is still in progress the same Deferred is returned
        again.

        @rtype: C{Deferred}
        @return: A Deferred that will fire with an instance of this class when
            the pollable reaches the OK state. It fails if the pollable's
            poll() raises, or with L{UnexpectedPollResult} if it reports an
            unknown state.
        """
        if self._pollingD is None:
            self._pollingD = defer.Deferred()
        ret = self._pollingD

        try:
            self._pollingState = self.pollable().poll()
        except Exception:
            # errback() without arguments wraps the exception currently being
            # handled in a Failure.
            d, self._pollingD = self._pollingD, None
            d.errback()
            return ret

        if self._pollingState == psycopg2.extensions.POLL_OK:
            d, self._pollingD = self._pollingD, None
            d.callback(self)
        elif self._pollingState == psycopg2.extensions.POLL_WRITE:
            # doWrite() will call poll() again once the descriptor is writable
            self.reactor.addWriter(self)
        elif self._pollingState == psycopg2.extensions.POLL_READ:
            # doRead() will call poll() again once the descriptor is readable
            self.reactor.addReader(self)
        else:
            d, self._pollingD = self._pollingD, None
            d.errback(UnexpectedPollResult())

        return ret

    def doRead(self):
        # Registrations are one-shot: poll() re-registers if it has to wait
        # again.
        self.reactor.removeReader(self)
        self.poll()

    def doWrite(self):
        # Registrations are one-shot: poll() re-registers if it has to wait
        # again.
        self.reactor.removeWriter(self)
        self.poll()

    def logPrefix(self):
        return self.prefix

    def fileno(self):
        # -1 when there is no pollable (e.g. not connected yet)
        pollable = self.pollable()
        if pollable:
            return pollable.fileno()
        return -1

    def connectionLost(self, reason):
        # Fail any pending poll() so its waiter is not left hanging.
        if self._pollingD is not None:
            d, self._pollingD = self._pollingD, None
            d.errback(reason)

    # forward all other access to the underlying pollable
    def __getattr__(self, name):
        return getattr(self.pollable(), name)
112
113
class Cursor(_PollingMixin):
    """
    A wrapper for a psycopg2 asynchronous cursor.

    The wrapper will forward almost everything to the wrapped cursor, so the
    usual DB-API interface can be used. It also prevents concurrent execution
    of asynchronous queries, which the PostgreSQL C library does not support,
    and returns Deferreds for some operations.
    """

    def __init__(self, cursor, connection):
        """
        @param cursor: The psycopg2 cursor to wrap.
        @param connection: The L{Connection} the cursor belongs to. Its
            reactor is used for polling and its lock serialises queries.
        """
        self.reactor = connection.reactor
        self.prefix = "cursor"

        self._connection = connection
        self._cursor = cursor

    def pollable(self):
        # Polling happens on the cursor's underlying psycopg2 connection.
        return self._cursor.connection

    def execute(self, query, params=None):
        """
        A regular DB-API execute, but returns a Deferred.

        The call goes through the connection's lock, so only one query runs on
        a connection at a time.

        @rtype: C{Deferred}
        @return: A C{Deferred} that will fire with this Cursor once the
            execute() has completed; results are then available through the
            usual fetch methods.
        """
        return self._connection.lock.run(
            self._doit, 'execute', query, params)

    def callproc(self, procname, params=None):
        """
        A regular DB-API callproc, but returns a Deferred.

        The call goes through the connection's lock, like L{execute}.

        @rtype: C{Deferred}
        @return: A C{Deferred} that will fire with this Cursor once the
            callproc() has completed.
        """
        return self._connection.lock.run(
            self._doit, 'callproc', procname, params)

    def _doit(self, name, *args, **kwargs):
        """
        Call the wrapped cursor's C{name} method and wait for it to complete.

        @return: The Deferred from L{poll}, or a failed Deferred if the call
            raised synchronously.
        """
        try:
            getattr(self._cursor, name)(*args, **kwargs)
        except Exception:
            return defer.fail()

        return self.poll()

    def close(self):
        """
        Close the wrapped cursor.

        Closing a Cursor that has already been closed does nothing and
        returns None.
        """
        _cursor, self._cursor = self._cursor, None
        if _cursor is None:
            return None
        return _cursor.close()

    def fileno(self):
        # -1 once this cursor or its owning Connection has been closed
        if self._cursor and self._connection._connection:
            return self._cursor.connection.fileno()
        else:
            return -1

    def __getattr__(self, name):
        # the pollable is the connection, but the wrapped object is the cursor
        return getattr(self._cursor, name)
177
178
class AlreadyConnected(Exception):
    """
    Signals that connect() was called on a wrapper that already holds an
    open database connection.
    """
    pass
183
184
class RollbackFailed(Exception):
    """
    Rolling back the transaction failed, the connection might be in an unusable
    state.

    @type connection: L{Connection}
    @ivar connection: The connection that failed to roll back its transaction.

    @type originalFailure: L{failure.Failure}
    @ivar originalFailure: The failure that caused the connection to try to
        roll back the transaction.
    """

    def __init__(self, connection, originalFailure):
        # Forward both values to Exception so that self.args is populated.
        # Without it, repr() hides the details and unpickling fails, because
        # it calls RollbackFailed(*self.args) with no arguments.
        super(RollbackFailed, self).__init__(connection, originalFailure)
        self.connection = connection
        self.originalFailure = originalFailure

    def __str__(self):
        # Format a 1-tuple so a tuple-valued failure is printed, not unpacked.
        return "<RollbackFailed, original error: %s>" % (self.originalFailure,)
204
205
class Connection(_PollingMixin):
    """
    A wrapper for a psycopg2 asynchronous connection.

    The wrapper forwards almost everything to the wrapped connection, but
    provides additional methods for compatibility with C{adbapi.Connection}.

    @type connectionFactory: Any callable.
    @ivar connectionFactory: The factory used to produce connections. Plain
        Python functions assigned at class level must be wrapped in
        C{staticmethod()} so they are not bound to the instance.

    @type cursorFactory: Any callable.
    @ivar cursorFactory: The factory used to produce cursors.
    """

    # staticmethod: when psycopg2.connect is a plain Python function (as in
    # current psycopg2 releases), looking it up through self would otherwise
    # bind it and pass this wrapper as the dsn argument.
    connectionFactory = staticmethod(psycopg2.connect)
    cursorFactory = Cursor

    def __init__(self, reactor=None):
        if not reactor:
            from twisted.internet import reactor
        self.reactor = reactor
        self.prefix = "connection"

        # this lock will be used to prevent concurrent query execution
        self.lock = defer.DeferredLock()
        # the wrapped psycopg2 connection; None until connect() is called and
        # again after close()
        self._connection = None

    def pollable(self):
        return self._connection

    def connect(self, *args, **kwargs):
        """
        Connect to the database.

        Positional and keyword arguments will be passed to connectionFactory
        (psycopg2.connect() by default). Use them to pass database names,
        usernames, passwords, etc. The 'async' keyword is always forced to
        True.

        NOTE(review): if polling the new connection fails, the half-open
        connection stays stored, so a retry fails with L{AlreadyConnected}
        until close() is called.

        @rtype: C{Deferred}
        @return: A Deferred that will fire with this Connection when the
            connection is open, or fail with L{AlreadyConnected} if one is
            already held.
        """
        if self._connection:
            return defer.fail(AlreadyConnected())

        kwargs['async'] = True
        try:
            self._connection = self.connectionFactory(*args, **kwargs)
        except Exception:
            return defer.fail()

        return self.poll()

    def close(self):
        """
        Close the connection and disconnect from the database.

        Closing a connection that was never opened, or that has already been
        closed, does nothing.
        """
        _connection, self._connection = self._connection, None
        if _connection is not None:
            _connection.close()

    def cursor(self):
        """
        Create an asynchronous cursor.

        @return: A cursorFactory instance (L{Cursor} by default) wrapping a
            new psycopg2 cursor.
        """
        return self.cursorFactory(self._connection.cursor(), self)

    def runQuery(self, *args, **kwargs):
        """
        Execute an SQL query and return the result.

        An asynchronous cursor will be created and its execute() method will
        be invoked with the provided *args and **kwargs. After the query
        completes the cursor's fetchall() method will be called and the
        returned Deferred will fire with the result.

        The connection is always in autocommit mode, so the query will be run
        in a one-off transaction. In case of errors a Failure will be returned.

        @rtype: C{Deferred}
        @return: A Deferred that will fire with the return value of the
            cursor's fetchall() method.
        """
        c = self.cursor()
        d = c.execute(*args, **kwargs)
        # execute() fires with the cursor itself
        return d.addCallback(lambda cursor: cursor.fetchall())

    def runOperation(self, *args, **kwargs):
        """
        Execute an SQL statement, discarding its result.

        Identical to runQuery, but the cursor's fetchall() method will not be
        called and instead None will be returned. It is intended for statements
        that do not normally return values, like INSERT or DELETE.

        @rtype: C{Deferred}
        @return: A Deferred that will fire with None.
        """
        c = self.cursor()
        d = c.execute(*args, **kwargs)
        return d.addCallback(lambda _: None)

    def runInteraction(self, interaction, *args, **kwargs):
        """
        Run commands in a transaction and return the result.

        The 'interaction' is a callable that will be passed a
        C{pgadbapi.Cursor} object. Before calling 'interaction' a new
        transaction will be started, so the callable can assume to be running
        all its commands in a transaction. If 'interaction' returns a
        C{Deferred} processing will wait for it to fire before proceeding.

        After 'interaction' finishes work the transaction will be automatically
        committed. If it raises an exception or returns a C{Failure} the
        connection will be rolled back instead.

        If committing the transaction fails it will be rolled back instead and
        the C{Failure} obtained trying to commit will be returned.

        If rolling back the transaction fails the C{Failure} obtained from the
        rollback attempt will be logged and a C{RollbackFailed} failure will be
        returned. The returned failure will contain references to the original
        C{Failure} that caused the transaction to be rolled back and to the
        C{Connection} in which that happened, so the user can decide whether
        to keep using it or just close it, because a transaction might have
        been left open in the database.

        @type interaction: Any callable
        @param interaction: A callable whose first argument is a
            L{pgadbapi.Cursor}.

        @rtype: C{Deferred}
        @return: A Deferred that will fire with the return value of
            'interaction'.
        """
        c = self.cursor()
        d = c.execute("begin")
        # execute() fires with the cursor, which becomes interaction's first
        # argument
        d.addCallback(interaction, *args, **kwargs)

        def commitAndPassthrough(ret, cursor):
            e = cursor.execute("commit")
            return e.addCallback(lambda _: ret)

        def rollbackAndPassthrough(f, cursor):
            # maybeDeferred in case cursor.execute raises a synchronous
            # exception
            e = defer.maybeDeferred(cursor.execute, "rollback")

            def just_panic(rf):
                log.err(rf)
                return defer.fail(RollbackFailed(self, f))
            # if rollback failed panic
            e.addErrback(just_panic)
            # reraise the original failure afterwards (skipped when the
            # rollback failed, so RollbackFailed propagates instead)
            return e.addCallback(lambda _: f)

        # added after the interaction, so a failing commit is also rolled back
        d.addCallback(commitAndPassthrough, c)
        d.addErrback(rollbackAndPassthrough, c)

        return d
360
class ConnectionPool(object):
    """
    A poor man's pool of L{pgadbapi.Connection} instances.

    Every run* call waits on a semaphore sized to the number of pooled
    connections, checks out one idle connection for the duration of the call
    and puts it back when the call's Deferred fires.

    @type min: C{int}
    @ivar min: The number of connections that will be open at start. The pool
        never opens or closes connections on its own.

    @type connectionFactory: Any callable.
    @ivar connectionFactory: The factory used to produce connections.
    """

    min = 3
    connectionFactory = Connection
    reactor = None

    def __init__(self, _ignored, *connargs, **connkw):
        """
        Create a new connection pool.

        Any positional or keyword arguments other than the first one and a
        'min' keyword argument are passed to the L{Connection} when
        connecting. Use these arguments to pass database names, usernames,
        passwords, etc.

        @type _ignored: Any object.
        @param _ignored: Ignored, for L{adbapi.ConnectionPool} compatibility.
        """
        if not self.reactor:
            from twisted.internet import reactor
            self.reactor = reactor
        # for adbapi compatibility, min can be passed in kwargs
        if 'min' in connkw:
            self.min = connkw.pop('min')
        self.connargs = connargs
        self.connkw = connkw
        # idle connections; a connection is absent from this set while a
        # run* call has it checked out
        self.connections = set(
            [self.connectionFactory(self.reactor) for _ in range(self.min)])

        # to avoid checking out more connections than there are pooled in total
        self._semaphore = defer.DeferredSemaphore(self.min)

    def start(self):
        """
        Start the connection pool.

        This will connect every pooled connection, as many as the pool's
        'min' variable says.

        @rtype: C{Deferred}
        @return: A C{Deferred} that fires with this pool when all connections
            have succeeded.
        """
        d = defer.gatherResults([c.connect(*self.connargs, **self.connkw)
                                 for c in self.connections])
        return d.addCallback(lambda _: self)

    def close(self):
        """
        Stop the pool.

        Disconnect all connections currently in the pool. Connections checked
        out by a running call are not in the pool at that moment and are not
        closed here.
        """
        for c in self.connections:
            c.close()

    def remove(self, connection):
        """
        Remove a connection from the pool.

        Provided to be able to remove broken connections from the pool. The
        caller should make sure the removed connection does not have queries
        pending.

        @type connection: An object produced by the pool's connection factory.
        @param connection: The connection to be removed.

        @raise ValueError: If the connection is not idle in the pool (it is
            checked out by a running call, or was never added).
        """
        if connection not in self.connections:
            raise ValueError("Connection still in use")
        self.connections.remove(connection)
        # Shrink the semaphore along with the pool by lowering the limit and
        # permanently holding one token. NOTE(review): a token is only held
        # while a connection is checked out, so with an idle connection just
        # removed this acquire() is expected to succeed immediately.
        self._semaphore.limit -= 1
        self._semaphore.acquire()

    def add(self, connection):
        """
        Add a connection to the pool.

        Provided to be able to extend the pool with new connections.

        @type connection: An object compatible with those produced by the
            pool's connection factory.
        @param connection: The connection to be added.
        """
        # The connection must be in the set before release() possibly wakes a
        # waiting call, which will immediately pop a connection.
        self.connections.add(connection)
        # Grow the semaphore: raise the limit first so the extra release()
        # stays within it.
        self._semaphore.limit += 1
        self._semaphore.release()

    def _putBackAndPassthrough(self, result, connection):
        self.connections.add(connection)
        return result

    def _withConnection(self, call):
        """
        Check out an idle connection and return C{call(connection)}, putting
        the connection back into the pool when the resulting Deferred fires.

        If C{call} raises synchronously, the connection is put back at once
        and the exception propagates to the caller (the semaphore's run()), so
        the pool never loses a connection to such an error.
        """
        c = self.connections.pop()
        try:
            d = call(c)
        except BaseException:
            self.connections.add(c)
            raise
        return d.addBoth(self._putBackAndPassthrough, c)

    def runQuery(self, *args, **kwargs):
        """
        Execute an SQL query and return the result.

        An asynchronous cursor will be created from a randomly chosen pooled
        connection and its execute() method will be invoked with the provided
        *args and **kwargs. After the query completes the cursor's fetchall()
        method will be called and the returned Deferred will fire with the
        result.

        The connection is always in autocommit mode, so the query will be run
        in a one-off transaction. In case of errors a Failure will be returned.

        @rtype: C{Deferred}
        @return: A Deferred that will fire with the return value of the
            cursor's fetchall() method.
        """
        return self._semaphore.run(self._runQuery, *args, **kwargs)

    def _runQuery(self, *args, **kwargs):
        return self._withConnection(lambda c: c.runQuery(*args, **kwargs))

    def runOperation(self, *args, **kwargs):
        """
        Execute an SQL statement, discarding its result.

        Identical to runQuery, but the cursor's fetchall() method will not be
        called and instead None will be returned. It is intended for statements
        that do not normally return values, like INSERT or DELETE.

        @rtype: C{Deferred}
        @return: A Deferred that will fire with None.
        """
        return self._semaphore.run(self._runOperation, *args, **kwargs)

    def _runOperation(self, *args, **kwargs):
        return self._withConnection(lambda c: c.runOperation(*args, **kwargs))

    def runInteraction(self, interaction, *args, **kwargs):
        """
        Run commands in a transaction and return the result.

        The 'interaction' is a callable that will be passed a
        C{pgadbapi.Cursor} object. Before calling 'interaction' a new
        transaction will be started, so the callable can assume to be running
        all its commands in a transaction. If 'interaction' returns a
        C{Deferred} processing will wait for it to fire before proceeding.

        After 'interaction' finishes work the transaction will be automatically
        committed. If it raises an exception or returns a C{Failure} the
        connection will be rolled back instead.

        If committing the transaction fails it will be rolled back instead and
        the C{Failure} obtained trying to commit will be returned.

        If rolling back the transaction fails the C{Failure} obtained from the
        rollback attempt will be logged and a C{RollbackFailed} failure will be
        returned. The returned failure will contain references to the original
        C{Failure} that caused the transaction to be rolled back and to the
        C{Connection} in which that happened, so the user can decide whether
        to keep using it or just close it, because a transaction might have
        been left open in the database.

        @type interaction: Any callable
        @param interaction: A callable whose first argument is a
            L{pgadbapi.Cursor}.

        @rtype: C{Deferred}
        @return: A Deferred that will fire with the return value of
            'interaction'.
        """
        return self._semaphore.run(
            self._runInteraction, interaction, *args, **kwargs)

    def _runInteraction(self, interaction, *args, **kwargs):
        return self._withConnection(
            lambda c: c.runInteraction(interaction, *args, **kwargs))