Index: twisted/enterprise/adbapi.py
===================================================================
RCS file: /cvs/Twisted/twisted/enterprise/adbapi.py,v
retrieving revision 1.50
diff -u -r1.50 adbapi.py
--- twisted/enterprise/adbapi.py	14 Jul 2003 01:31:44 -0000	1.50
+++ twisted/enterprise/adbapi.py	19 Jul 2003 00:20:50 -0000
@@ -25,7 +25,7 @@
 
 class Transaction:
     """
-    I am a lightweight wrapper for a database 'cursor' object.  I relay
+    I am a lightweight wrapper for a DB-API 'cursor' object.  I relay
     attribute access to the DB cursor.
     """
     _cursor = None
@@ -46,23 +46,20 @@
 class ConnectionPool(pb.Referenceable):
     """I represent a pool of connections to a DB-API 2.0 compliant database.
 
-    You can pass the noisy arg which determines whether informational
-    log messages are generated during the pool's operation.
+    You can pass cp_min, cp_max or both to set the minimum and maximum
+    number of connections that will be opened by the pool. You can pass
+    the noisy arg which determines whether informational log messages are
+    generated during the pool's operation.
     """
-    noisy = 1
 
-    # XXX - make the min and max attributes (and cp_min and cp_max
-    # kwargs to __init__) actually do something?
-    min = 3
-    max = 5
+    noisy = 1   # if true, generate informational log messages
+    min = 3     # minimum number of connections in pool
+    max = 5     # maximum number of connections in pool
 
     def __init__(self, dbapiName, *connargs, **connkw):
         """See ConnectionPool.__doc__
         """
         self.dbapiName = dbapiName
-        if self.noisy:
-            log.msg("Connecting to database: %s %s %s" %
-                    (dbapiName, connargs, connkw))
         self.dbapi = reflect.namedModule(dbapiName)
 
         if getattr(self.dbapi, 'apilevel', None) != '2.0':
@@ -74,10 +71,6 @@
         self.connargs = connargs
         self.connkw = connkw
 
-        import thread
-        self.threadID = thread.get_ident
-        self.connections = {}
-
         if connkw.has_key('cp_min'):
             self.min = connkw['cp_min']
             del connkw['cp_min']
@@ -90,7 +83,21 @@
             self.noisy = connkw['cp_noisy']
             del connkw['cp_noisy']
 
+        self.min = min(self.min, self.max)
+        self.max = max(self.min, self.max)
+
+        self.connections = {}  # all connections, hashed on thread id
+
+        # these are optional so import them here
+        from twisted.python import threadpool
+        import thread
+
+        self.threadID = thread.get_ident
+        self.threadpool = threadpool.ThreadPool(self.min, self.max,
+                                                self.connect)
+
         from twisted.internet import reactor
+        reactor.callWhenRunning(self.threadpool.start)
         self.shutdownID = reactor.addSystemEventTrigger('during', 'shutdown',
                                                         self.finalClose)
 
@@ -98,12 +105,12 @@
         """Interact with the database and return the result.
 
         The 'interaction' is a callable object which will be executed in a
-        pooled thread.  It will be passed an L{Transaction} object as an
-        argument (whose interface is identical to that of the database cursor
-        for your DB-API module of choice), and its results will be returned as
-        a Deferred.  If running the method raises an exception, the transaction
-        will be rolled back.  If the method returns a value, the transaction
-        will be committed.
+        thread using a pooled connection. It will be passed an L{Transaction}
+        object as an argument (whose interface is identical to that of the
+        database cursor for your DB-API module of choice), and its results
+        will be returned as a Deferred. If running the method raises an
+        exception, the transaction will be rolled back. If the method returns
+        a value, the transaction will be committed.
 
         @param interaction: a callable object whose first argument is
             L{adbapi.Transaction}.
@@ -117,33 +124,103 @@
         apply(self.interaction, (interaction,d.callback,d.errback,)+args, kw)
         return d
 
-    def __getstate__(self):
-        return {'dbapiName': self.dbapiName,
-                'noisy': self.noisy,
-                'min': self.min,
-                'max': self.max,
-                'connargs': self.connargs,
-                'connkw': self.connkw}
+    def runQuery(self, *args, **kw):
+        """Execute an SQL query and return the result.
 
-    def __setstate__(self, state):
-        self.__dict__ = state
-        apply(self.__init__, (self.dbapiName, )+self.connargs, self.connkw)
+        A DB-API cursor will will be invoked with cursor.execute(*args, **kw).
+        The exact nature of the arguments will depend on the specific flavor
+        of DB-API being used, but the first argument in *args be an SQL
+        statement. The result of a subsequent cursor.fetchall() will be
+        fired to the Deferred which is returned. If either the 'execute' or
+        'fetchall' methods raise an exception, the transaction will be rolled
+        back and a Failure returned.
+
+        @param *args,**kw: arguments to be passed to a DB-API cursor's
+        'execute' method.
+
+        @return: a Deferred which will fire the return value of a DB-API
+        cursor's 'fetchall' method, or a Failure.
+        """
+
+        d = defer.Deferred()
+        apply(self.query, (d.callback, d.errback)+args, kw)
+        return d
+
+    def runOperation(self, *args, **kw):
+        """Execute an SQL query and return None.
+
+        A DB-API cursor will will be invoked with cursor.execute(*args, **kw).
+        The exact nature of the arguments will depend on the specific flavor
+        of DB-API being used, but the first argument in *args will be an SQL
+        statement. This method will not attempt to fetch any results from the
+        query and is thus suitable for INSERT, DELETE, and other SQL statements
+        which do not return values. If the 'execute' method raises an exception,
+        the transaction will be rolled back and a Failure returned.
+
+        @param *args,**kw: arguments to be passed to a DB-API cursor's
+        'execute' method.
+
+        @return: a Deferred which will fire None or a Failure.
+        """
+
+        d = defer.Deferred()
+        apply(self.operation, (d.callback, d.errback)+args, kw)
+        return d
+
+    def close(self):
+        """Close all pool connections and shutdown the pool.
+
+        Connections will be closed even if they are in use!
+        """
+
+        from twisted.internet import reactor
+        reactor.removeSystemEventTrigger(self.shutdownID)
+        self.finalClose()
+
+    def finalClose(self):
+        """This should only be called by the shutdown trigger."""
+
+        self.threadpool.stop()
+        for connection in self.connections.values():
+            if self.noisy:
+                log.msg('adbapi closing: %s %s%s' % (self.dbapiName,
+                                                     self.connargs or '',
+                                                     self.connkw or ''))
+            connection.close()
+        self.connections.clear()
 
     def connect(self):
-        """Should be run in thread, blocks.
+        """Return a database connection when one becomes available. This method blocks and should be run in a thread from the internal threadpool.
 
         Don't call this method directly from non-threaded twisted code.
+
+        @return: a database connection from the pool.
         """
+
         tid = self.threadID()
         conn = self.connections.get(tid)
-        if not conn:
+        if conn is None:
+            if self.noisy:
+                log.msg('adbapi connecting: %s %s%s' % (self.dbapiName,
+                                                        self.connargs or '',
+                                                        self.connkw or ''))
             conn = apply(self.dbapi.connect, self.connargs, self.connkw)
             self.connections[tid] = conn
-            if self.noisy:
-                log.msg('adbapi connecting: %s %s%s' %
-                    ( self.dbapiName, self.connargs or '', self.connkw or ''))
         return conn
 
+    def _runInteraction(self, interaction, *args, **kw):
+        trans = Transaction(self, self.connect())
+        try:
+            result = apply(interaction, (trans,)+args, kw)
+            trans.close()
+            trans._connection.commit()
+            return result
+        except:
+            log.msg('Exception in SQL interaction. Rolling back.')
+            log.deferr()
+            trans._connection.rollback()
+            raise
+
     def _runQuery(self, args, kw):
         conn = self.connect()
         curs = conn.cursor()
@@ -154,32 +231,55 @@
             conn.commit()
             return result
         except:
+            log.msg('Exception in SQL query. Rolling back.')
+            log.deferr()
             conn.rollback()
             raise
 
     def _runOperation(self, args, kw):
         conn = self.connect()
         curs = conn.cursor()
-
         try:
             apply(curs.execute, args, kw)
-            result = None
             curs.close()
             conn.commit()
         except:
-            # XXX - failures aren't working here
+            log.msg('Exception in SQL operation. Rolling back.')
+            log.deferr()
             conn.rollback()
             raise
-        return result
+
+    def __getstate__(self):
+        return {'dbapiName': self.dbapiName,
+                'noisy': self.noisy,
+                'min': self.min,
+                'max': self.max,
+                'connargs': self.connargs,
+                'connkw': self.connkw}
+
+    def __setstate__(self, state):
+        self.__dict__ = state
+        apply(self.__init__, (self.dbapiName, )+self.connargs, self.connkw)
+
+    def _deferToThread(self, f, *args, **kwargs):
+        """Internal function.
+
+        Call f in one of the connection pool's threads.
+        """
+
+        d = defer.Deferred()
+        self.threadpool.callInThread(threads._putResultInDeferred,
+                                     d, f, args, kwargs)
+        return d
 
     def query(self, callback, errback, *args, **kw):
         # this will be deprecated ASAP
-        threads.deferToThread(self._runQuery, args, kw).addCallbacks(
+        self._deferToThread(self._runQuery, args, kw).addCallbacks(
             callback, errback)
 
     def operation(self, callback, errback, *args, **kw):
         # this will be deprecated ASAP
-        threads.deferToThread(self._runOperation, args, kw).addCallbacks(
+        self._deferToThread(self._runOperation, args, kw).addCallbacks(
             callback, errback)
 
     def synchronousOperation(self, *args, **kw):
@@ -187,48 +287,10 @@
 
     def interaction(self, interaction, callback, errback, *args, **kw):
         # this will be deprecated ASAP
-        apply(threads.deferToThread,
+        apply(self._deferToThread,
               (self._runInteraction, interaction) + args, kw).addCallbacks(
             callback, errback)
 
-    def runOperation(self, *args, **kw):
-        """Run a SQL statement and return a Deferred of result."""
-        d = defer.Deferred()
-        apply(self.operation, (d.callback,d.errback)+args, kw)
-        return d
-
-    def runQuery(self, *args, **kw):
-        """Run a read-only query and return a Deferred."""
-        d = defer.Deferred()
-        apply(self.query, (d.callback, d.errback)+args, kw)
-        return d
-
-    def _runInteraction(self, interaction, *args, **kw):
-        trans = Transaction(self, self.connect())
-        try:
-            result = apply(interaction, (trans,)+args, kw)
-        except:
-            log.msg('Exception in SQL interaction!  rolling back...')
-            log.deferr()
-            trans._connection.rollback()
-            raise
-        else:
-            trans._cursor.close()
-            trans._connection.commit()
-            return result
-
-    def close(self):
-        from twisted.internet import reactor
-        reactor.removeSystemEventTrigger(self.shutdownID)
-        self.finalClose()
-
-    def finalClose(self):
-        for connection in self.connections.values():
-            if self.noisy:
-                log.msg('adbapi closing: %s %s%s' % (self.dbapiName,
-                                                     self.connargs or '',
-                                                     self.connkw or ''))
-            connection.close()
 
 class Augmentation:
     '''A class which augments a database connector with some functionality.
@@ -242,11 +304,9 @@
 
     def __init__(self, dbpool):
         self.dbpool = dbpool
-        #self.createSchema()
 
     def __setstate__(self, state):
         self.__dict__ = state
-        #self.createSchema()
 
     def operationDone(self, done):
         """Example callback for database operation success.
Index: twisted/internet/base.py
===================================================================
RCS file: /cvs/Twisted/twisted/internet/base.py,v
retrieving revision 1.58
diff -u -r1.58 base.py
--- twisted/internet/base.py	10 Jul 2003 01:35:03 -0000	1.58
+++ twisted/internet/base.py	19 Jul 2003 00:20:51 -0000
@@ -163,6 +163,7 @@
     def __init__(self):
         self._eventTriggers = {}
         self._pendingTimedCalls = []
+        self.running = 0
         self.waker = None
         self.resolver = None
         self.usingThreads = 0
@@ -349,6 +350,14 @@
                                         "after":  2}[phase]
                                        ].remove(item)
 
+    def callWhenRunning(self, callable, *args, **kw):
+        """See twisted.internet.interfaces.IReactorCore.callWhenRunning.
+        """
+        if self.running:
+            callable(*args, **kw)
+        else:
+            self.addSystemEventTrigger('after', 'startup',
+                                       callable, *args, **kw)
 
     # IReactorTime
 
Index: twisted/internet/interfaces.py
===================================================================
RCS file: /cvs/Twisted/twisted/internet/interfaces.py,v
retrieving revision 1.88
diff -u -r1.88 interfaces.py
--- twisted/internet/interfaces.py	1 Jul 2003 05:07:24 -0000	1.88
+++ twisted/internet/interfaces.py	19 Jul 2003 00:20:54 -0000
@@ -402,7 +402,7 @@
 
         @param args: the arguments to call it with.
 
-        @param kw: they keyword arguments to call it with.
+        @param kw: the keyword arguments to call it with.
 
         @returns: An L{IDelayedCall} object that can be used to cancel
                   the scheduled call, by calling its C{cancel()} method.
@@ -589,6 +589,22 @@
         """Removes a trigger added with addSystemEventTrigger.
 
         @param triggerID: a value returned from addSystemEventTrigger.
+        """
+
+    def callWhenRunning(self, callable, *args, **kw):
+        """Call a function when the reactor is running.
+
+        If the reactor has not started, the callable will be scheduled
+        to run when it does start. Otherwise, the callable will be invoked
+        immediately.
+
+        @param callable: the callable object to call later.
+
+        @param args: the arguments to call it with.
+
+        @param kw: the keyword arguments to call it with.
+
+        @returns: None
         """
 
 
Index: twisted/python/threadpool.py
===================================================================
RCS file: /cvs/Twisted/twisted/python/threadpool.py,v
retrieving revision 1.20
diff -u -r1.20 threadpool.py
--- twisted/python/threadpool.py	1 May 2003 12:34:40 -0000	1.20
+++ twisted/python/threadpool.py	19 Jul 2003 00:20:55 -0000
@@ -54,26 +54,41 @@
     joined = 0
     started = 0
     workers = 0
-    
-    def __init__(self, minthreads=5, maxthreads=20):
+
+    def __init__(self, minthreads=5, maxthreads=20,
+                 init=None, *initargs, **initkw):
+        """Create a new threadpool.
+
+        @param minthreads: minimum number of threads in the pool
+
+        @param maxthreads: maximum number of threads in the pool
+
+        @param init: initialization function called from new threads
+
+        @param *initargs, **initkw: additional arguments to be passed to 'init'
+        """
+
         assert minthreads <= maxthreads, 'minimum is greater than maximum'
         self.q = Queue.Queue(0)
         self.min = minthreads
         self.max = maxthreads
+        self.init = init
+        self.initargs = initargs
+        self.initkw = initkw
         if runtime.platform.getType() != "java":
             self.waiters = []
         else:
             self.waiters = ThreadSafeList()
         self.threads = []
         self.working = {}
-    
+
     def start(self):
         """Start the threadpool.
         """
-        self.workers = self.min
+        self.workers = min(max(self.min, self.q.qsize()), self.max)
         self.joined = 0
         self.started = 1
-        for i in range(self.min):
+        for i in range(self.workers):
             name = "PoolThread-%s-%s" % (id(self), i)
             threading.Thread(target=self._worker, name=name).start()
 
@@ -86,7 +101,7 @@
         state['min'] = self.min
         state['max'] = self.max
         return state
-    
+
     def _startSomeWorkers(self):
         if not self.waiters:
             if self.workers < self.max:
@@ -124,9 +139,11 @@
         self.callInThread(self._runWithCallback, callback, errback, func, args, kw)
 
     def _worker(self):
+        if self.init:
+            self.init(*self.initargs, **self.initkw)
         ct = threading.currentThread()
         self.threads.append(ct)
-        
+
         while 1:
             self.waiters.append(ct)
             o = self.q.get()
Index: twisted/test/test_enterprise.py
===================================================================
RCS file: /cvs/Twisted/twisted/test/test_enterprise.py,v
retrieving revision 1.16
diff -u -r1.16 test_enterprise.py
--- twisted/test/test_enterprise.py	5 Jul 2003 21:04:32 -0000	1.16
+++ twisted/test/test_enterprise.py	19 Jul 2003 00:20:57 -0000
@@ -22,13 +22,15 @@
 import os
 import random
 
-from twisted.trial.util import deferredResult
 from twisted.enterprise.row import RowObject
 from twisted.enterprise.reflector import *
 from twisted.enterprise.xmlreflector import XMLReflector
 from twisted.enterprise.sqlreflector import SQLReflector
 from twisted.enterprise.adbapi import ConnectionPool
 from twisted.enterprise import util
+from twisted.internet import defer
+from twisted.trial.util import deferredResult, deferredError
+from twisted.python import log
 
 try: import gadfly
 except: gadfly = None
@@ -89,6 +91,12 @@
 )
 """
 
+simple_table_schema = """
+CREATE TABLE simple (
+  x integer
+)
+"""
+
 def randomizeRow(row, nullsOK=1, trailingSpacesOK=1):
     values = {}
     for name, type in row.rowColumns:
@@ -298,39 +306,94 @@
     DB_USER = 'twisted_test'
     DB_PASS = 'twisted_test'
 
+    can_rollback = 1
+
     reflectorClass = SQLReflector
 
     def createReflector(self):
         self.startDB()
         self.dbpool = self.makePool()
+        self.dbpool.threadpool.start() # since the reactor never really starts
         deferredResult(self.dbpool.runOperation(main_table_schema))
         deferredResult(self.dbpool.runOperation(child_table_schema))
+        deferredResult(self.dbpool.runOperation(simple_table_schema))
         return self.reflectorClass(self.dbpool, [TestRow, ChildRow])
 
     def destroyReflector(self):
         deferredResult(self.dbpool.runOperation('DROP TABLE testTable'))
         deferredResult(self.dbpool.runOperation('DROP TABLE childTable'))
+        deferredResult(self.dbpool.runOperation('DROP TABLE simple'))
         self.dbpool.close()
         self.stopDB()
 
-    def startDB(self): pass
-    def stopDB(self): pass
+    def testPool(self):
+        # make sure failures are raised correctly
+        deferredError(self.dbpool.runQuery("select * from NOTABLE"))
+        deferredError(self.dbpool.runOperation("delete from * from NOTABLE"))
+        deferredError(self.dbpool.runInteraction(self.bad_interaction))
+        log.flushErrors()
+
+        # verify simple table is empty
+        sql = "select count(1) from simple"
+        row = deferredResult(self.dbpool.runQuery(sql))
+        self.failUnless(int(row[0][0]) == 0, "Interaction not rolled back")
+
+        # add some rows to simple table (runOperation)
+        for i in range(self.count):
+            sql = "insert into simple(x) values(%d)" % i
+            deferredResult(self.dbpool.runOperation(sql))
+
+        # make sure they were added (runQuery)
+        sql = "select x from simple order by x";
+        rows = deferredResult(self.dbpool.runQuery(sql))
+        self.failUnless(len(rows) == self.count, "Wrong number of rows")
+        for i in range(self.count):
+            self.failUnless(len(rows[i]) == 1, "Wrong size row")
+            self.failUnless(rows[i][0] == i, "Values not returned.")
+
+        # runInteraction
+        deferredResult(self.dbpool.runInteraction(self.interaction))
+
+        # give the pool a workout
+        ds = []
+        for i in range(self.count):
+            sql = "select x from simple where x = %d" % i
+            ds.append(self.dbpool.runQuery(sql))
+        dlist = defer.DeferredList(ds, fireOnOneErrback=1)
+        result = deferredResult(dlist)
+        for i in range(self.count):
+            self.failUnless(result[i][1][0][0] == i, "Value not returned")
+
+        # now delete everything
+        ds = []
+        for i in range(self.count):
+            sql = "delete from simple where x = %d" % i
+            ds.append(self.dbpool.runOperation(sql))
+        dlist = defer.DeferredList(ds, fireOnOneErrback=1)
+        deferredResult(dlist)
+
+        # verify simple table is empty
+        sql = "select count(1) from simple"
+        row = deferredResult(self.dbpool.runQuery(sql))
+        self.failUnless(int(row[0][0]) == 0, "Interaction not rolled back")
+
+    def interaction(self, transaction):
+        transaction.execute("select x from simple order by x")
+        for i in range(self.count):
+            row = transaction.fetchone()
+            self.failUnless(len(row) == 1, "Wrong size row")
+            self.failUnless(row[0] == i, "Value not returned.")
+        # should test this, but gadfly throws an exception instead
+        #self.failUnless(transaction.fetchone() is None, "Too many rows")
+
+    def bad_interaction(self, transaction):
+        if self.can_rollback:
+            transaction.execute("insert into simple(x) values(0)")
 
+        transaction.execute("select * from NOTABLE")
 
-class SinglePool(ConnectionPool):
-    """A pool for just one connection at a time.
-    Remove this when ConnectionPool is fixed.
-    """
-
-    def __init__(self, connection):
-        self.connection = connection
-
-    def connect(self):
-        return self.connection
-
-    def close(self):
-        self.connection.close()
-        del self.connection
+    def startDB(self): pass
+    def stopDB(self): pass
 
 
 class NoSlashSQLReflector(SQLReflector):
@@ -346,6 +409,7 @@
     nullsOK = 0
     DB_DIR = "./gadflyDB"
     reflectorClass = NoSlashSQLReflector
+    can_rollback = 0
 
     def startDB(self):
         if not os.path.exists(self.DB_DIR): os.mkdir(self.DB_DIR)
@@ -359,7 +423,7 @@
         conn.close()
 
     def makePool(self):
-        return SinglePool(gadfly.gadfly(self.DB_NAME, self.DB_DIR))
+        return ConnectionPool('gadfly', self.DB_NAME, self.DB_DIR, cp_max=1)
 
 
 class SQLiteTestCase(SQLReflectorTestCase, unittest.TestCase):
@@ -375,7 +439,7 @@
         if os.path.exists(self.database): os.unlink(self.database)
 
     def makePool(self):
-        return SinglePool(sqlite.connect(database=self.database))
+        return ConnectionPool('sqlite', database=self.database, cp_max=1)
 
 
 class PostgresTestCase(SQLReflectorTestCase, unittest.TestCase):
@@ -384,7 +448,8 @@
 
     def makePool(self):
         return ConnectionPool('pyPgSQL.PgSQL', database=self.DB_NAME,
-                              user=self.DB_USER, password=self.DB_PASS)
+                              user=self.DB_USER, password=self.DB_PASS,
+                              cp_min=0)
 
 
 class MySQLTestCase(SQLReflectorTestCase, unittest.TestCase):
@@ -392,6 +457,7 @@
     """
 
     trailingSpacesOK = 0
+    can_rollback = 0
 
     def makePool(self):
         return ConnectionPool('MySQLdb', db=self.DB_NAME,
@@ -410,6 +476,8 @@
 
 
 if gadfly is None: GadflyTestCase.skip = 1
+elif not getattr(gadfly, 'connect', None): gadfly.connect = gadfly.gadfly
+
 if sqlite is None: SQLiteTestCase.skip = 1
 
 if PgSQL is None: PostgresTestCase.skip = 1
