Index: twisted/enterprise/adbapi.py
===================================================================
RCS file: /cvs/Twisted/twisted/enterprise/adbapi.py,v
retrieving revision 1.50
diff -u -r1.50 adbapi.py
--- twisted/enterprise/adbapi.py	14 Jul 2003 01:31:44 -0000	1.50
+++ twisted/enterprise/adbapi.py	18 Jul 2003 08:25:47 -0000
@@ -25,7 +25,7 @@
 
 class Transaction:
     """
-    I am a lightweight wrapper for a database 'cursor' object.  I relay
+    I am a lightweight wrapper for a DB-API 'cursor' object.  I relay
     attribute access to the DB cursor.
     """
     _cursor = None
@@ -46,23 +46,21 @@
 class ConnectionPool(pb.Referenceable):
     """I represent a pool of connections to a DB-API 2.0 compliant database.
 
-    You can pass the noisy arg which determines whether informational
-    log messages are generated during the pool's operation.
+    You can pass cp_min, cp_max or both to set the minimum and maximum
+    number of connections that will be opened by the pool. You can pass
+    the noisy arg which determines whether informational log messages are
+    generated during the pool's operation.
     """
-    noisy = 1
 
-    # XXX - make the min and max attributes (and cp_min and cp_max
-    # kwargs to __init__) actually do something?
-    min = 3
-    max = 5
+    noisy = 1   # if true, generate informational log messages
+    min = 3     # minimum number of connections in pool
+    max = 5     # maximum number of connections in pool
+    started = 0 # 1 if threadpool has been started
 
     def __init__(self, dbapiName, *connargs, **connkw):
         """See ConnectionPool.__doc__
         """
         self.dbapiName = dbapiName
-        if self.noisy:
-            log.msg("Connecting to database: %s %s %s" %
-                    (dbapiName, connargs, connkw))
         self.dbapi = reflect.namedModule(dbapiName)
 
         if getattr(self.dbapi, 'apilevel', None) != '2.0':
@@ -74,10 +72,6 @@
         self.connargs = connargs
         self.connkw = connkw
 
-        import thread
-        self.threadID = thread.get_ident
-        self.connections = {}
-
         if connkw.has_key('cp_min'):
             self.min = connkw['cp_min']
             del connkw['cp_min']
@@ -90,7 +84,22 @@
             self.noisy = connkw['cp_noisy']
             del connkw['cp_noisy']
 
+        self.min = min(self.min, self.max)
+        self.max = max(self.min, self.max)
+
+        self.connections = {}  # all connections, hashed on thread id
+
+        # these are optional so import them here
+        from twisted.python import threadpool
+        import thread
+
+        self.threadID = thread.get_ident
+        self.threadpool = threadpool.ThreadPool(self.min, self.max)
+
+        # TODO: start up min connections
+
         from twisted.internet import reactor
+        reactor.callLater(0, self.threadpool.start)
         self.shutdownID = reactor.addSystemEventTrigger('during', 'shutdown',
                                                         self.finalClose)
 
@@ -98,12 +107,12 @@
         """Interact with the database and return the result.
 
         The 'interaction' is a callable object which will be executed in a
-        pooled thread.  It will be passed an L{Transaction} object as an
-        argument (whose interface is identical to that of the database cursor
-        for your DB-API module of choice), and its results will be returned as
-        a Deferred.  If running the method raises an exception, the transaction
-        will be rolled back.  If the method returns a value, the transaction
-        will be committed.
+        thread using a pooled connection. It will be passed an L{Transaction}
+        object as an argument (whose interface is identical to that of the
+        database cursor for your DB-API module of choice), and its results
+        will be returned as a Deferred. If running the method raises an
+        exception, the transaction will be rolled back. If the method returns
+        a value, the transaction will be committed.
 
         @param interaction: a callable object whose first argument is
             L{adbapi.Transaction}.
@@ -117,33 +126,103 @@
         apply(self.interaction, (interaction,d.callback,d.errback,)+args, kw)
         return d
 
-    def __getstate__(self):
-        return {'dbapiName': self.dbapiName,
-                'noisy': self.noisy,
-                'min': self.min,
-                'max': self.max,
-                'connargs': self.connargs,
-                'connkw': self.connkw}
+    def runQuery(self, *args, **kw):
+        """Execute an SQL query and return the result.
 
-    def __setstate__(self, state):
-        self.__dict__ = state
-        apply(self.__init__, (self.dbapiName, )+self.connargs, self.connkw)
+        A DB-API cursor will will be invoked with cursor.execute(*args, **kw).
+        The exact nature of the arguments will depend on the specific flavor
+        of DB-API being used, but the first argument in *args be an SQL
+        statement. The result of a subsequent cursor.fetchall() will be
+        fired to the Deferred which is returned. If either the 'execute' or
+        'fetchall' methods raise an exception, the transaction will be rolled
+        back and a Failure returned.
+
+        @param *args,**kw: arguments to be passed to a DB-API cursor's
+        'execute' method.
+
+        @return: a Deferred which will fire the return value of a DB-API
+        cursor's 'fetchall' method, or a Failure.
+        """
+
+        d = defer.Deferred()
+        apply(self.query, (d.callback, d.errback)+args, kw)
+        return d
+
+    def runOperation(self, *args, **kw):
+        """Execute an SQL query and return None.
+
+        A DB-API cursor will will be invoked with cursor.execute(*args, **kw).
+        The exact nature of the arguments will depend on the specific flavor
+        of DB-API being used, but the first argument in *args will be an SQL
+        statement. This method will not attempt to fetch any results from the
+        query and is thus suitable for INSERT, DELETE, and other SQL statements
+        which do not return values. If the 'execute' method raises an exception,
+        the transaction will be rolled back and a Failure returned.
+
+        @param *args,**kw: arguments to be passed to a DB-API cursor's
+        'execute' method.
+
+        @return: a Deferred which will fire None or a Failure.
+        """
+
+        d = defer.Deferred()
+        apply(self.operation, (d.callback, d.errback)+args, kw)
+        return d
+
+    def close(self):
+        """Close all pool connections and shutdown the pool.
+
+        Connections will be closed even if they are in use!
+        """
+
+        from twisted.internet import reactor
+        reactor.removeSystemEventTrigger(self.shutdownID)
+        self.finalClose()
+
+    def finalClose(self):
+        """This should only be called by the shutdown trigger."""
+
+        self.threadpool.stop()
+        for connection in self.connections.values():
+            if self.noisy:
+                log.msg('adbapi closing: %s %s%s' % (self.dbapiName,
+                                                     self.connargs or '',
+                                                     self.connkw or ''))
+            connection.close()
+        self.connections.clear()
 
     def connect(self):
-        """Should be run in thread, blocks.
+        """Return a database connection when one becomes available. This method blocks and should be run in a thread from the internal threadpool.
 
         Don't call this method directly from non-threaded twisted code.
+
+        @return: a database connection from the pool.
         """
+
         tid = self.threadID()
         conn = self.connections.get(tid)
-        if not conn:
+        if conn is None:
+            if self.noisy:
+                log.msg('adbapi connecting: %s %s%s' % (self.dbapiName,
+                                                        self.connargs or '',
+                                                        self.connkw or ''))
             conn = apply(self.dbapi.connect, self.connargs, self.connkw)
             self.connections[tid] = conn
-            if self.noisy:
-                log.msg('adbapi connecting: %s %s%s' %
-                    ( self.dbapiName, self.connargs or '', self.connkw or ''))
         return conn
 
+    def _runInteraction(self, interaction, *args, **kw):
+        trans = Transaction(self, self.connect())
+        try:
+            result = apply(interaction, (trans,)+args, kw)
+            trans.close()
+            trans._connection.commit()
+            return result
+        except:
+            log.msg('Exception in SQL interaction. Rolling back.')
+            log.deferr()
+            trans._connection.rollback()
+            raise
+
     def _runQuery(self, args, kw):
         conn = self.connect()
         curs = conn.cursor()
@@ -154,32 +233,55 @@
             conn.commit()
             return result
         except:
+            log.msg('Exception in SQL query. Rolling back.')
+            log.deferr()
             conn.rollback()
             raise
 
     def _runOperation(self, args, kw):
         conn = self.connect()
         curs = conn.cursor()
-
         try:
             apply(curs.execute, args, kw)
-            result = None
             curs.close()
             conn.commit()
         except:
-            # XXX - failures aren't working here
+            log.msg('Exception in SQL operation. Rolling back.')
+            log.deferr()
             conn.rollback()
             raise
-        return result
+
+    def __getstate__(self):
+        return {'dbapiName': self.dbapiName,
+                'noisy': self.noisy,
+                'min': self.min,
+                'max': self.max,
+                'connargs': self.connargs,
+                'connkw': self.connkw}
+
+    def __setstate__(self, state):
+        self.__dict__ = state
+        apply(self.__init__, (self.dbapiName, )+self.connargs, self.connkw)
+
+    def _deferToThread(self, f, *args, **kwargs):
+        """Internal function.
+
+        Call f in one of the connection pool's threads.
+        """
+
+        d = defer.Deferred()
+        self.threadpool.callInThread(threads._putResultInDeferred,
+                                     d, f, args, kwargs)
+        return d
 
     def query(self, callback, errback, *args, **kw):
         # this will be deprecated ASAP
-        threads.deferToThread(self._runQuery, args, kw).addCallbacks(
+        self._deferToThread(self._runQuery, args, kw).addCallbacks(
             callback, errback)
 
     def operation(self, callback, errback, *args, **kw):
         # this will be deprecated ASAP
-        threads.deferToThread(self._runOperation, args, kw).addCallbacks(
+        self._deferToThread(self._runOperation, args, kw).addCallbacks(
             callback, errback)
 
     def synchronousOperation(self, *args, **kw):
@@ -187,48 +289,10 @@
 
     def interaction(self, interaction, callback, errback, *args, **kw):
         # this will be deprecated ASAP
-        apply(threads.deferToThread,
+        apply(self._deferToThread,
               (self._runInteraction, interaction) + args, kw).addCallbacks(
             callback, errback)
 
-    def runOperation(self, *args, **kw):
-        """Run a SQL statement and return a Deferred of result."""
-        d = defer.Deferred()
-        apply(self.operation, (d.callback,d.errback)+args, kw)
-        return d
-
-    def runQuery(self, *args, **kw):
-        """Run a read-only query and return a Deferred."""
-        d = defer.Deferred()
-        apply(self.query, (d.callback, d.errback)+args, kw)
-        return d
-
-    def _runInteraction(self, interaction, *args, **kw):
-        trans = Transaction(self, self.connect())
-        try:
-            result = apply(interaction, (trans,)+args, kw)
-        except:
-            log.msg('Exception in SQL interaction!  rolling back...')
-            log.deferr()
-            trans._connection.rollback()
-            raise
-        else:
-            trans._cursor.close()
-            trans._connection.commit()
-            return result
-
-    def close(self):
-        from twisted.internet import reactor
-        reactor.removeSystemEventTrigger(self.shutdownID)
-        self.finalClose()
-
-    def finalClose(self):
-        for connection in self.connections.values():
-            if self.noisy:
-                log.msg('adbapi closing: %s %s%s' % (self.dbapiName,
-                                                     self.connargs or '',
-                                                     self.connkw or ''))
-            connection.close()
 
 class Augmentation:
     '''A class which augments a database connector with some functionality.
@@ -242,11 +306,9 @@
 
     def __init__(self, dbpool):
         self.dbpool = dbpool
-        #self.createSchema()
 
     def __setstate__(self, state):
         self.__dict__ = state
-        #self.createSchema()
 
     def operationDone(self, done):
         """Example callback for database operation success.
Index: twisted/python/threadpool.py
===================================================================
RCS file: /cvs/Twisted/twisted/python/threadpool.py,v
retrieving revision 1.20
diff -u -r1.20 threadpool.py
--- twisted/python/threadpool.py	1 May 2003 12:34:40 -0000	1.20
+++ twisted/python/threadpool.py	18 Jul 2003 08:25:48 -0000
@@ -70,10 +70,10 @@
     def start(self):
         """Start the threadpool.
         """
-        self.workers = self.min
+        self.workers = min(max(self.min, self.q.qsize()), self.max)
         self.joined = 0
         self.started = 1
-        for i in range(self.min):
+        for i in range(self.workers):
             name = "PoolThread-%s-%s" % (id(self), i)
             threading.Thread(target=self._worker, name=name).start()
 
@@ -107,7 +107,7 @@
         self.q.put(o)
         if self.started and not self.waiters:
             self._startSomeWorkers()
-    
+
     def _runWithCallback(self, callback, errback, func, args, kwargs):
         try:
             result = apply(func, args, kwargs)
Index: twisted/test/test_enterprise.py
===================================================================
RCS file: /cvs/Twisted/twisted/test/test_enterprise.py,v
retrieving revision 1.16
diff -u -r1.16 test_enterprise.py
--- twisted/test/test_enterprise.py	5 Jul 2003 21:04:32 -0000	1.16
+++ twisted/test/test_enterprise.py	18 Jul 2003 08:25:49 -0000
@@ -22,13 +22,15 @@
 import os
 import random
 
-from twisted.trial.util import deferredResult
 from twisted.enterprise.row import RowObject
 from twisted.enterprise.reflector import *
 from twisted.enterprise.xmlreflector import XMLReflector
 from twisted.enterprise.sqlreflector import SQLReflector
 from twisted.enterprise.adbapi import ConnectionPool
 from twisted.enterprise import util
+from twisted.internet import defer
+from twisted.trial.util import deferredResult, deferredError
+from twisted.python import log
 
 try: import gadfly
 except: gadfly = None
@@ -89,6 +91,12 @@
 )
 """
 
+simple_table_schema = """
+CREATE TABLE simple (
+  x integer
+)
+"""
+
 def randomizeRow(row, nullsOK=1, trailingSpacesOK=1):
     values = {}
     for name, type in row.rowColumns:
@@ -298,6 +306,8 @@
     DB_USER = 'twisted_test'
     DB_PASS = 'twisted_test'
 
+    can_rollback = 1
+
     reflectorClass = SQLReflector
 
     def createReflector(self):
@@ -305,32 +315,84 @@
         self.dbpool = self.makePool()
         deferredResult(self.dbpool.runOperation(main_table_schema))
         deferredResult(self.dbpool.runOperation(child_table_schema))
+        deferredResult(self.dbpool.runOperation(simple_table_schema))
         return self.reflectorClass(self.dbpool, [TestRow, ChildRow])
 
     def destroyReflector(self):
         deferredResult(self.dbpool.runOperation('DROP TABLE testTable'))
         deferredResult(self.dbpool.runOperation('DROP TABLE childTable'))
+        deferredResult(self.dbpool.runOperation('DROP TABLE simple'))
         self.dbpool.close()
         self.stopDB()
 
-    def startDB(self): pass
-    def stopDB(self): pass
+    def testPool(self):
+        # make sure failures are raised correctly
+        deferredError(self.dbpool.runQuery("select * from NOTABLE"))
+        deferredError(self.dbpool.runOperation("delete from * from NOTABLE"))
+        deferredError(self.dbpool.runInteraction(self.bad_interaction))
+        log.flushErrors()
+
+        # verify simple table is empty
+        sql = "select count(1) from simple"
+        row = deferredResult(self.dbpool.runQuery(sql))
+        self.failUnless(int(row[0][0]) == 0, "Interaction not rolled back")
+
+        # add some rows to simple table (runOperation)
+        for i in range(self.count):
+            sql = "insert into simple(x) values(%d)" % i
+            deferredResult(self.dbpool.runOperation(sql))
+
+        # make sure they were added (runQuery)
+        sql = "select x from simple order by x";
+        rows = deferredResult(self.dbpool.runQuery(sql))
+        self.failUnless(len(rows) == self.count, "Wrong number of rows")
+        for i in range(self.count):
+            self.failUnless(len(rows[i]) == 1, "Wrong size row")
+            self.failUnless(rows[i][0] == i, "Values not returned.")
+
+        # runInteraction
+        deferredResult(self.dbpool.runInteraction(self.interaction))
+
+        # give the pool a workout
+        ds = []
+        for i in range(self.count):
+            sql = "select x from simple where x = %d" % i
+            ds.append(self.dbpool.runQuery(sql))
+        dlist = defer.DeferredList(ds, fireOnOneErrback=1)
+        result = deferredResult(dlist)
+        for i in range(self.count):
+            self.failUnless(result[i][1][0][0] == i, "Value not returned")
+
+        # now delete everything
+        ds = []
+        for i in range(self.count):
+            sql = "delete from simple where x = %d" % i
+            ds.append(self.dbpool.runOperation(sql))
+        dlist = defer.DeferredList(ds, fireOnOneErrback=1)
+        deferredResult(dlist)
+
+        # verify simple table is empty
+        sql = "select count(1) from simple"
+        row = deferredResult(self.dbpool.runQuery(sql))
+        self.failUnless(int(row[0][0]) == 0, "Interaction not rolled back")
+
+    def interaction(self, transaction):
+        transaction.execute("select x from simple order by x")
+        for i in range(self.count):
+            row = transaction.fetchone()
+            self.failUnless(len(row) == 1, "Wrong size row")
+            self.failUnless(row[0] == i, "Value not returned.")
+        # should test this, but gadfly throws an exception instead
+        #self.failUnless(transaction.fetchone() is None, "Too many rows")
+
+    def bad_interaction(self, transaction):
+        if self.can_rollback:
+            transaction.execute("insert into simple(x) values(0)")
 
+        transaction.execute("select * from NOTABLE")
 
-class SinglePool(ConnectionPool):
-    """A pool for just one connection at a time.
-    Remove this when ConnectionPool is fixed.
-    """
-
-    def __init__(self, connection):
-        self.connection = connection
-
-    def connect(self):
-        return self.connection
-
-    def close(self):
-        self.connection.close()
-        del self.connection
+    def startDB(self): pass
+    def stopDB(self): pass
 
 
 class NoSlashSQLReflector(SQLReflector):
@@ -346,6 +408,7 @@
     nullsOK = 0
     DB_DIR = "./gadflyDB"
     reflectorClass = NoSlashSQLReflector
+    can_rollback = 0
 
     def startDB(self):
         if not os.path.exists(self.DB_DIR): os.mkdir(self.DB_DIR)
@@ -359,7 +422,7 @@
         conn.close()
 
     def makePool(self):
-        return SinglePool(gadfly.gadfly(self.DB_NAME, self.DB_DIR))
+        return ConnectionPool('gadfly', self.DB_NAME, self.DB_DIR, cp_max=1)
 
 
 class SQLiteTestCase(SQLReflectorTestCase, unittest.TestCase):
@@ -375,7 +438,7 @@
         if os.path.exists(self.database): os.unlink(self.database)
 
     def makePool(self):
-        return SinglePool(sqlite.connect(database=self.database))
+        return ConnectionPool('sqlite', database=self.database, cp_max=1)
 
 
 class PostgresTestCase(SQLReflectorTestCase, unittest.TestCase):
@@ -384,7 +447,8 @@
 
     def makePool(self):
         return ConnectionPool('pyPgSQL.PgSQL', database=self.DB_NAME,
-                              user=self.DB_USER, password=self.DB_PASS)
+                              user=self.DB_USER, password=self.DB_PASS,
+                              cp_min=0)
 
 
 class MySQLTestCase(SQLReflectorTestCase, unittest.TestCase):
@@ -392,6 +456,7 @@
     """
 
     trailingSpacesOK = 0
+    can_rollback = 0
 
     def makePool(self):
         return ConnectionPool('MySQLdb', db=self.DB_NAME,
@@ -410,6 +475,8 @@
 
 
 if gadfly is None: GadflyTestCase.skip = 1
+elif not getattr(gadfly, 'connect', None): gadfly.connect = gadfly.gadfly
+
 if sqlite is None: SQLiteTestCase.skip = 1
 
 if PgSQL is None: PostgresTestCase.skip = 1
