Skip to content

Commit ac47727

Browse files
committed
Support optional iterator protocol on cursors (#38)
Use the original iterator if the cursor is already iterable, make it iterable otherwise.
1 parent b147745 commit ac47727

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

dbutils/steady_db.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,14 @@ def __exit__(self, *exc):
542542
"""Exit the runtime context for the cursor object."""
543543
self.close()
544544

545+
def __iter__(self):
546+
"""Make cursor compatible to the iteration protocol."""
547+
cursor = self._cursor
548+
try: # use iterator provided by original cursor
549+
return iter(cursor)
550+
except TypeError: # create iterator if not provided
551+
return iter(cursor.fetchone, None)
552+
545553
def setinputsizes(self, sizes):
546554
"""Store input sizes in case cursor needs to be reopened."""
547555
self._inputsizes = sizes

tests/test_steady_db.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,36 @@ def test_cursor_context_handler(self):
273273
self.assertEqual(cursor.fetchone(), 'test')
274274
self.assertEqual(db._con.open_cursors, 0)
275275

276+
def test_cursor_as_iterator_provided(self):
277+
db = SteadyDBconnect(
278+
dbapi, 0, None, None, None, True,
279+
'SteadyDBTestDB', user='SteadyDBTestUser')
280+
self.assertEqual(db._con.open_cursors, 0)
281+
cursor = db.cursor()
282+
self.assertEqual(db._con.open_cursors, 1)
283+
cursor.execute('select test')
284+
_cursor = cursor._cursor
285+
try:
286+
assert not hasattr(_cursor, 'iter')
287+
_cursor.__iter__ = lambda: ['test-iter']
288+
assert list(iter(cursor)) == ['test']
289+
finally:
290+
del _cursor.__iter__
291+
cursor.close()
292+
self.assertEqual(db._con.open_cursors, 0)
293+
294+
def test_cursor_as_iterator_created(self):
295+
db = SteadyDBconnect(
296+
dbapi, 0, None, None, None, True,
297+
'SteadyDBTestDB', user='SteadyDBTestUser')
298+
self.assertEqual(db._con.open_cursors, 0)
299+
cursor = db.cursor()
300+
self.assertEqual(db._con.open_cursors, 1)
301+
cursor.execute('select test')
302+
assert list(iter(cursor)) == ['test']
303+
cursor.close()
304+
self.assertEqual(db._con.open_cursors, 0)
305+
276306
def test_connection_creator_function(self):
277307
db1 = SteadyDBconnect(
278308
dbapi, 0, None, None, None, True,

0 commit comments

Comments
 (0)