Index: twisted/enterprise/adbapi.py =================================================================== RCS file: /cvs/Twisted/twisted/enterprise/adbapi.py,v retrieving revision 1.50 diff -u -r1.50 adbapi.py --- twisted/enterprise/adbapi.py 14 Jul 2003 01:31:44 -0000 1.50 +++ twisted/enterprise/adbapi.py 19 Jul 2003 00:20:50 -0000 @@ -25,7 +25,7 @@ class Transaction: """ - I am a lightweight wrapper for a database 'cursor' object. I relay + I am a lightweight wrapper for a DB-API 'cursor' object. I relay attribute access to the DB cursor. """ _cursor = None @@ -46,23 +46,20 @@ class ConnectionPool(pb.Referenceable): """I represent a pool of connections to a DB-API 2.0 compliant database. - You can pass the noisy arg which determines whether informational - log messages are generated during the pool's operation. + You can pass cp_min, cp_max or both to set the minimum and maximum + number of connections that will be opened by the pool. You can pass + the noisy arg which determines whether informational log messages are + generated during the pool's operation. """ - noisy = 1 - # XXX - make the min and max attributes (and cp_min and cp_max - # kwargs to __init__) actually do something? - min = 3 - max = 5 + noisy = 1 # if true, generate informational log messages + min = 3 # minimum number of connections in pool + max = 5 # maximum number of connections in pool def __init__(self, dbapiName, *connargs, **connkw): """See ConnectionPool.__doc__ """ self.dbapiName = dbapiName - if self.noisy: - log.msg("Connecting to database: %s %s %s" % - (dbapiName, connargs, connkw)) self.dbapi = reflect.namedModule(dbapiName) if getattr(self.dbapi, 'apilevel', None) != '2.0': @@ -74,10 +71,6 @@ self.connargs = connargs self.connkw = connkw - import thread - self.threadID = thread.get_ident - self.connections = {} - if connkw.has_key('cp_min'): self.min = connkw['cp_min'] del connkw['cp_min'] @@ -90,7 +83,21 @@ self.noisy = connkw['cp_noisy'] del connkw['cp_noisy'] + self.min = min(self.min, self.max) + self.max = max(self.min, self.max) + + self.connections = {} # all connections, hashed on thread id + + # these are optional so import them here + from twisted.python import threadpool + import thread + + self.threadID = thread.get_ident + self.threadpool = threadpool.ThreadPool(self.min, self.max, + self.connect) + from twisted.internet import reactor + reactor.callWhenRunning(self.threadpool.start) self.shutdownID = reactor.addSystemEventTrigger('during', 'shutdown', self.finalClose) @@ -98,12 +105,12 @@ """Interact with the database and return the result. The 'interaction' is a callable object which will be executed in a - pooled thread. It will be passed an L{Transaction} object as an - argument (whose interface is identical to that of the database cursor - for your DB-API module of choice), and its results will be returned as - a Deferred. If running the method raises an exception, the transaction - will be rolled back. If the method returns a value, the transaction - will be committed. + thread using a pooled connection. It will be passed an L{Transaction} + object as an argument (whose interface is identical to that of the + database cursor for your DB-API module of choice), and its results + will be returned as a Deferred. If running the method raises an + exception, the transaction will be rolled back. If the method returns + a value, the transaction will be committed. @param interaction: a callable object whose first argument is L{adbapi.Transaction}. @@ -117,33 +124,103 @@ apply(self.interaction, (interaction,d.callback,d.errback,)+args, kw) return d - def __getstate__(self): - return {'dbapiName': self.dbapiName, - 'noisy': self.noisy, - 'min': self.min, - 'max': self.max, - 'connargs': self.connargs, - 'connkw': self.connkw} + def runQuery(self, *args, **kw): + """Execute an SQL query and return the result. - def __setstate__(self, state): - self.__dict__ = state - apply(self.__init__, (self.dbapiName, )+self.connargs, self.connkw) + A DB-API cursor will will be invoked with cursor.execute(*args, **kw). + The exact nature of the arguments will depend on the specific flavor + of DB-API being used, but the first argument in *args be an SQL + statement. The result of a subsequent cursor.fetchall() will be + fired to the Deferred which is returned. If either the 'execute' or + 'fetchall' methods raise an exception, the transaction will be rolled + back and a Failure returned. + + @param *args,**kw: arguments to be passed to a DB-API cursor's + 'execute' method. + + @return: a Deferred which will fire the return value of a DB-API + cursor's 'fetchall' method, or a Failure. + """ + + d = defer.Deferred() + apply(self.query, (d.callback, d.errback)+args, kw) + return d + + def runOperation(self, *args, **kw): + """Execute an SQL query and return None. + + A DB-API cursor will will be invoked with cursor.execute(*args, **kw). + The exact nature of the arguments will depend on the specific flavor + of DB-API being used, but the first argument in *args will be an SQL + statement. This method will not attempt to fetch any results from the + query and is thus suitable for INSERT, DELETE, and other SQL statements + which do not return values. If the 'execute' method raises an exception, + the transaction will be rolled back and a Failure returned. + + @param *args,**kw: arguments to be passed to a DB-API cursor's + 'execute' method. + + @return: a Deferred which will fire None or a Failure. + """ + + d = defer.Deferred() + apply(self.operation, (d.callback, d.errback)+args, kw) + return d + + def close(self): + """Close all pool connections and shutdown the pool. + + Connections will be closed even if they are in use! + """ + + from twisted.internet import reactor + reactor.removeSystemEventTrigger(self.shutdownID) + self.finalClose() + + def finalClose(self): + """This should only be called by the shutdown trigger.""" + + self.threadpool.stop() + for connection in self.connections.values(): + if self.noisy: + log.msg('adbapi closing: %s %s%s' % (self.dbapiName, + self.connargs or '', + self.connkw or '')) + connection.close() + self.connections.clear() def connect(self): - """Should be run in thread, blocks. + """Return a database connection when one becomes available. This method blocks and should be run in a thread from the internal threadpool. Don't call this method directly from non-threaded twisted code. + + @return: a database connection from the pool. """ + tid = self.threadID() conn = self.connections.get(tid) - if not conn: + if conn is None: + if self.noisy: + log.msg('adbapi connecting: %s %s%s' % (self.dbapiName, + self.connargs or '', + self.connkw or '')) conn = apply(self.dbapi.connect, self.connargs, self.connkw) self.connections[tid] = conn - if self.noisy: - log.msg('adbapi connecting: %s %s%s' % - ( self.dbapiName, self.connargs or '', self.connkw or '')) return conn + def _runInteraction(self, interaction, *args, **kw): + trans = Transaction(self, self.connect()) + try: + result = apply(interaction, (trans,)+args, kw) + trans.close() + trans._connection.commit() + return result + except: + log.msg('Exception in SQL interaction. Rolling back.') + log.deferr() + trans._connection.rollback() + raise + def _runQuery(self, args, kw): conn = self.connect() curs = conn.cursor() @@ -154,32 +231,55 @@ conn.commit() return result except: + log.msg('Exception in SQL query. Rolling back.') + log.deferr() conn.rollback() raise def _runOperation(self, args, kw): conn = self.connect() curs = conn.cursor() - try: apply(curs.execute, args, kw) - result = None curs.close() conn.commit() except: - # XXX - failures aren't working here + log.msg('Exception in SQL operation. Rolling back.') + log.deferr() conn.rollback() raise - return result + + def __getstate__(self): + return {'dbapiName': self.dbapiName, + 'noisy': self.noisy, + 'min': self.min, + 'max': self.max, + 'connargs': self.connargs, + 'connkw': self.connkw} + + def __setstate__(self, state): + self.__dict__ = state + apply(self.__init__, (self.dbapiName, )+self.connargs, self.connkw) + + def _deferToThread(self, f, *args, **kwargs): + """Internal function. + + Call f in one of the connection pool's threads. + """ + + d = defer.Deferred() + self.threadpool.callInThread(threads._putResultInDeferred, + d, f, args, kwargs) + return d def query(self, callback, errback, *args, **kw): # this will be deprecated ASAP - threads.deferToThread(self._runQuery, args, kw).addCallbacks( + self._deferToThread(self._runQuery, args, kw).addCallbacks( callback, errback) def operation(self, callback, errback, *args, **kw): # this will be deprecated ASAP - threads.deferToThread(self._runOperation, args, kw).addCallbacks( + self._deferToThread(self._runOperation, args, kw).addCallbacks( callback, errback) def synchronousOperation(self, *args, **kw): @@ -187,48 +287,10 @@ def interaction(self, interaction, callback, errback, *args, **kw): # this will be deprecated ASAP - apply(threads.deferToThread, + apply(self._deferToThread, (self._runInteraction, interaction) + args, kw).addCallbacks( callback, errback) - def runOperation(self, *args, **kw): - """Run a SQL statement and return a Deferred of result.""" - d = defer.Deferred() - apply(self.operation, (d.callback,d.errback)+args, kw) - return d - - def runQuery(self, *args, **kw): - """Run a read-only query and return a Deferred.""" - d = defer.Deferred() - apply(self.query, (d.callback, d.errback)+args, kw) - return d - - def _runInteraction(self, interaction, *args, **kw): - trans = Transaction(self, self.connect()) - try: - result = apply(interaction, (trans,)+args, kw) - except: - log.msg('Exception in SQL interaction! rolling back...') - log.deferr() - trans._connection.rollback() - raise - else: - trans._cursor.close() - trans._connection.commit() - return result - - def close(self): - from twisted.internet import reactor - reactor.removeSystemEventTrigger(self.shutdownID) - self.finalClose() - - def finalClose(self): - for connection in self.connections.values(): - if self.noisy: - log.msg('adbapi closing: %s %s%s' % (self.dbapiName, - self.connargs or '', - self.connkw or '')) - connection.close() class Augmentation: '''A class which augments a database connector with some functionality. @@ -242,11 +304,9 @@ def __init__(self, dbpool): self.dbpool = dbpool - #self.createSchema() def __setstate__(self, state): self.__dict__ = state - #self.createSchema() def operationDone(self, done): """Example callback for database operation success. Index: twisted/internet/base.py =================================================================== RCS file: /cvs/Twisted/twisted/internet/base.py,v retrieving revision 1.58 diff -u -r1.58 base.py --- twisted/internet/base.py 10 Jul 2003 01:35:03 -0000 1.58 +++ twisted/internet/base.py 19 Jul 2003 00:20:51 -0000 @@ -163,6 +163,7 @@ def __init__(self): self._eventTriggers = {} self._pendingTimedCalls = [] + self.running = 0 self.waker = None self.resolver = None self.usingThreads = 0 @@ -349,6 +350,14 @@ "after": 2}[phase] ].remove(item) + def callWhenRunning(self, callable, *args, **kw): + """See twisted.internet.interfaces.IReactorCore.callWhenRunning. + """ + if self.running: + callable(*args, **kw) + else: + self.addSystemEventTrigger('after', 'startup', + callable, *args, **kw) # IReactorTime Index: twisted/internet/interfaces.py =================================================================== RCS file: /cvs/Twisted/twisted/internet/interfaces.py,v retrieving revision 1.88 diff -u -r1.88 interfaces.py --- twisted/internet/interfaces.py 1 Jul 2003 05:07:24 -0000 1.88 +++ twisted/internet/interfaces.py 19 Jul 2003 00:20:54 -0000 @@ -402,7 +402,7 @@ @param args: the arguments to call it with. - @param kw: they keyword arguments to call it with. + @param kw: the keyword arguments to call it with. @returns: An L{IDelayedCall} object that can be used to cancel the scheduled call, by calling its C{cancel()} method. @@ -589,6 +589,22 @@ """Removes a trigger added with addSystemEventTrigger. @param triggerID: a value returned from addSystemEventTrigger. + """ + + def callWhenRunning(self, callable, *args, **kw): + """Call a function when the reactor is running. + + If the reactor has not started, the callable will be scheduled + to run when it does start. Otherwise, the callable will be invoked + immediately. + + @param callable: the callable object to call later. + + @param args: the arguments to call it with. + + @param kw: the keyword arguments to call it with. + + @returns: None """ Index: twisted/python/threadpool.py =================================================================== RCS file: /cvs/Twisted/twisted/python/threadpool.py,v retrieving revision 1.20 diff -u -r1.20 threadpool.py --- twisted/python/threadpool.py 1 May 2003 12:34:40 -0000 1.20 +++ twisted/python/threadpool.py 19 Jul 2003 00:20:55 -0000 @@ -54,26 +54,41 @@ joined = 0 started = 0 workers = 0 - - def __init__(self, minthreads=5, maxthreads=20): + + def __init__(self, minthreads=5, maxthreads=20, + init=None, *initargs, **initkw): + """Create a new threadpool. + + @param minthreads: minimum number of threads in the pool + + @param maxthreads: maximum number of threads in the pool + + @param init: initialization function called from new threads + + @param *initargs, **initkw: additional arguments to be passed to 'init' + """ + assert minthreads <= maxthreads, 'minimum is greater than maximum' self.q = Queue.Queue(0) self.min = minthreads self.max = maxthreads + self.init = init + self.initargs = initargs + self.initkw = initkw if runtime.platform.getType() != "java": self.waiters = [] else: self.waiters = ThreadSafeList() self.threads = [] self.working = {} - + def start(self): """Start the threadpool. """ - self.workers = self.min + self.workers = min(max(self.min, self.q.qsize()), self.max) self.joined = 0 self.started = 1 - for i in range(self.min): + for i in range(self.workers): name = "PoolThread-%s-%s" % (id(self), i) threading.Thread(target=self._worker, name=name).start() @@ -86,7 +101,7 @@ state['min'] = self.min state['max'] = self.max return state - + def _startSomeWorkers(self): if not self.waiters: if self.workers < self.max: @@ -124,9 +139,11 @@ self.callInThread(self._runWithCallback, callback, errback, func, args, kw) def _worker(self): + if self.init: + self.init(*self.initargs, **self.initkw) ct = threading.currentThread() self.threads.append(ct) - + while 1: self.waiters.append(ct) o = self.q.get() Index: twisted/test/test_enterprise.py =================================================================== RCS file: /cvs/Twisted/twisted/test/test_enterprise.py,v retrieving revision 1.16 diff -u -r1.16 test_enterprise.py --- twisted/test/test_enterprise.py 5 Jul 2003 21:04:32 -0000 1.16 +++ twisted/test/test_enterprise.py 19 Jul 2003 00:20:57 -0000 @@ -22,13 +22,15 @@ import os import random -from twisted.trial.util import deferredResult from twisted.enterprise.row import RowObject from twisted.enterprise.reflector import * from twisted.enterprise.xmlreflector import XMLReflector from twisted.enterprise.sqlreflector import SQLReflector from twisted.enterprise.adbapi import ConnectionPool from twisted.enterprise import util +from twisted.internet import defer +from twisted.trial.util import deferredResult, deferredError +from twisted.python import log try: import gadfly except: gadfly = None @@ -89,6 +91,12 @@ ) """ +simple_table_schema = """ +CREATE TABLE simple ( + x integer +) +""" + def randomizeRow(row, nullsOK=1, trailingSpacesOK=1): values = {} for name, type in row.rowColumns: @@ -298,39 +306,94 @@ DB_USER = 'twisted_test' DB_PASS = 'twisted_test' + can_rollback = 1 + reflectorClass = SQLReflector def createReflector(self): self.startDB() self.dbpool = self.makePool() + self.dbpool.threadpool.start() # since the reactor never really starts deferredResult(self.dbpool.runOperation(main_table_schema)) deferredResult(self.dbpool.runOperation(child_table_schema)) + deferredResult(self.dbpool.runOperation(simple_table_schema)) return self.reflectorClass(self.dbpool, [TestRow, ChildRow]) def destroyReflector(self): deferredResult(self.dbpool.runOperation('DROP TABLE testTable')) deferredResult(self.dbpool.runOperation('DROP TABLE childTable')) + deferredResult(self.dbpool.runOperation('DROP TABLE simple')) self.dbpool.close() self.stopDB() - def startDB(self): pass - def stopDB(self): pass + def testPool(self): + # make sure failures are raised correctly + deferredError(self.dbpool.runQuery("select * from NOTABLE")) + deferredError(self.dbpool.runOperation("delete from * from NOTABLE")) + deferredError(self.dbpool.runInteraction(self.bad_interaction)) + log.flushErrors() + + # verify simple table is empty + sql = "select count(1) from simple" + row = deferredResult(self.dbpool.runQuery(sql)) + self.failUnless(int(row[0][0]) == 0, "Interaction not rolled back") + + # add some rows to simple table (runOperation) + for i in range(self.count): + sql = "insert into simple(x) values(%d)" % i + deferredResult(self.dbpool.runOperation(sql)) + + # make sure they were added (runQuery) + sql = "select x from simple order by x"; + rows = deferredResult(self.dbpool.runQuery(sql)) + self.failUnless(len(rows) == self.count, "Wrong number of rows") + for i in range(self.count): + self.failUnless(len(rows[i]) == 1, "Wrong size row") + self.failUnless(rows[i][0] == i, "Values not returned.") + + # runInteraction + deferredResult(self.dbpool.runInteraction(self.interaction)) + + # give the pool a workout + ds = [] + for i in range(self.count): + sql = "select x from simple where x = %d" % i + ds.append(self.dbpool.runQuery(sql)) + dlist = defer.DeferredList(ds, fireOnOneErrback=1) + result = deferredResult(dlist) + for i in range(self.count): + self.failUnless(result[i][1][0][0] == i, "Value not returned") + + # now delete everything + ds = [] + for i in range(self.count): + sql = "delete from simple where x = %d" % i + ds.append(self.dbpool.runOperation(sql)) + dlist = defer.DeferredList(ds, fireOnOneErrback=1) + deferredResult(dlist) + + # verify simple table is empty + sql = "select count(1) from simple" + row = deferredResult(self.dbpool.runQuery(sql)) + self.failUnless(int(row[0][0]) == 0, "Interaction not rolled back") + + def interaction(self, transaction): + transaction.execute("select x from simple order by x") + for i in range(self.count): + row = transaction.fetchone() + self.failUnless(len(row) == 1, "Wrong size row") + self.failUnless(row[0] == i, "Value not returned.") + # should test this, but gadfly throws an exception instead + #self.failUnless(transaction.fetchone() is None, "Too many rows") + + def bad_interaction(self, transaction): + if self.can_rollback: + transaction.execute("insert into simple(x) values(0)") + transaction.execute("select * from NOTABLE") -class SinglePool(ConnectionPool): - """A pool for just one connection at a time. - Remove this when ConnectionPool is fixed. - """ - - def __init__(self, connection): - self.connection = connection - - def connect(self): - return self.connection - - def close(self): - self.connection.close() - del self.connection + def startDB(self): pass + def stopDB(self): pass class NoSlashSQLReflector(SQLReflector): @@ -346,6 +409,7 @@ nullsOK = 0 DB_DIR = "./gadflyDB" reflectorClass = NoSlashSQLReflector + can_rollback = 0 def startDB(self): if not os.path.exists(self.DB_DIR): os.mkdir(self.DB_DIR) @@ -359,7 +423,7 @@ conn.close() def makePool(self): - return SinglePool(gadfly.gadfly(self.DB_NAME, self.DB_DIR)) + return ConnectionPool('gadfly', self.DB_NAME, self.DB_DIR, cp_max=1) class SQLiteTestCase(SQLReflectorTestCase, unittest.TestCase): @@ -375,7 +439,7 @@ if os.path.exists(self.database): os.unlink(self.database) def makePool(self): - return SinglePool(sqlite.connect(database=self.database)) + return ConnectionPool('sqlite', database=self.database, cp_max=1) class PostgresTestCase(SQLReflectorTestCase, unittest.TestCase): @@ -384,7 +448,8 @@ def makePool(self): return ConnectionPool('pyPgSQL.PgSQL', database=self.DB_NAME, - user=self.DB_USER, password=self.DB_PASS) + user=self.DB_USER, password=self.DB_PASS, + cp_min=0) class MySQLTestCase(SQLReflectorTestCase, unittest.TestCase): @@ -392,6 +457,7 @@ """ trailingSpacesOK = 0 + can_rollback = 0 def makePool(self): return ConnectionPool('MySQLdb', db=self.DB_NAME, @@ -410,6 +476,8 @@ if gadfly is None: GadflyTestCase.skip = 1 +elif not getattr(gadfly, 'connect', None): gadfly.connect = gadfly.gadfly + if sqlite is None: SQLiteTestCase.skip = 1 if PgSQL is None: PostgresTestCase.skip = 1