Source code for destral.patch
import logging
# Module-level logger named after this module (``__name__``); used by the
# patching context managers below to report what they patch/restore.
logger = logging.getLogger(__name__)
class RestorePatchedRegisterAll(object):
    """Context manager that restores ``report.interface.register_all``.

    On entry the current ``report.interface.register_all`` object is saved.
    On exit, if the attribute now refers to a different object (i.e. it was
    replaced while the context was active), the saved original is put back.
    Exceptions raised inside the ``with`` body are not suppressed.
    """

    def __enter__(self):
        # Imported inside the method rather than at module level; presumably
        # so this module can be imported before ``report`` is importable —
        # TODO confirm.
        import report
        logger.info(
            'Saving original register_all %s',
            id(report.interface.register_all),
        )
        self.orig = report.interface.register_all
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        import report
        current = report.interface.register_all
        # Identity check: only restore (and log) when something replaced the
        # saved callable while the context was active.
        if current is not self.orig:
            logger.info(
                'Restoring register_all %s to %s', id(current), id(self.orig)
            )
            report.interface.register_all = self.orig
class PatchedCursor(object):
    """Proxy around a database cursor that neutralises transaction control.

    ``commit``, ``rollback`` and ``close`` are no-ops that return ``True``
    without touching the wrapped cursor; every other attribute is looked up
    on the wrapped cursor.

    :param cursor: Cursor to wrap.
    """

    def __init__(self, cursor):
        self.cursor = cursor

    def commit(self):
        """Do nothing and return ``True``; the wrapped cursor is not committed."""
        return True

    def rollback(self, savepoint=False):
        """Do nothing and return ``True``; ``savepoint`` is accepted but ignored."""
        return True

    def close(self):
        """Do nothing and return ``True``; the wrapped cursor is not closed."""
        return True

    def __getattr__(self, item):
        """Delegate any attribute not found on the proxy to the wrapped cursor."""
        # ``__getattr__`` only runs after normal lookup fails. If ``cursor``
        # itself is missing (e.g. the instance was created without
        # ``__init__``, as ``copy``/``pickle`` do), reading ``self.cursor``
        # here would re-enter this method forever; fail cleanly instead.
        if item == 'cursor':
            raise AttributeError(item)
        return getattr(self.cursor, item)
class PatchedConnection(object):
    """Patched connection wrapper that always returns the same cursor.

    This is useful when some method inside a testing method creates new
    cursors: every call to :meth:`cursor` hands back one shared
    ``PatchedCursor`` wrapping the original cursor. Any other attribute is
    looked up on the original connection.

    :param connection: Original connection
    :param cursor: Original cursor
    """

    def __init__(self, connection, cursor):
        self._connection = connection
        self._cursor = PatchedCursor(cursor)

    def __getattr__(self, item):
        """Delegate any attribute not found on the wrapper to the connection."""
        # ``__getattr__`` only runs after normal lookup fails. If the
        # wrapper's own state is missing (e.g. the instance was created
        # without ``__init__``, as ``copy``/``pickle`` do), looking up
        # ``self._connection`` here would re-enter this method forever.
        if item in ('_connection', '_cursor'):
            raise AttributeError(item)
        return getattr(self._connection, item)

    def cursor(self, serialized=False):
        """Return the shared wrapped cursor.

        :param serialized: Ignored; presumably kept so the signature mirrors
            the original connection's ``cursor`` method — TODO confirm.
        :returns: The same ``PatchedCursor`` instance on every call.
        """
        return self._cursor
class PatchNewCursors(object):
    """Util to patch creation of new cursors.

    While the context is active, ``sql_db.db_connect`` is replaced by
    :meth:`db_connect`, so connections obtained through it always return the
    cursor created by ``Transaction`` (wrapped via ``PatchedConnection``).
    The original ``sql_db.db_connect`` is restored on exit; exceptions from
    the ``with`` body are not suppressed.
    """

    @staticmethod
    def db_connect(db_name):
        """Replacement for ``sql_db.db_connect``.

        :param db_name: Name of the database to connect to.
        :returns: A ``PatchedConnection`` around a new ``sql_db.Connection``
            for ``db_name``, whose ``cursor()`` always returns the current
            ``Transaction().cursor`` (wrapped).
        """
        from destral.transaction import Transaction
        import sql_db
        cursor = Transaction().cursor
        # NOTE(review): ``sql_db._Pool`` is passed through as-is; this assumes
        # the pool has already been initialised — confirm.
        conn = sql_db.Connection(sql_db._Pool, db_name)
        return PatchedConnection(conn, cursor)

    def __enter__(self):
        import sql_db
        logger.info('Patching creation of new cursors')
        self.orig = sql_db.db_connect
        sql_db.db_connect = PatchNewCursors.db_connect
        # Return the manager so ``with PatchNewCursors() as patcher:`` binds
        # it instead of ``None`` (same as RestorePatchedRegisterAll).
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        import sql_db
        logger.info('Unpatching creation of new cursors')
        sql_db.db_connect = self.orig