You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
366 lines
13 KiB
366 lines
13 KiB
# pylint: disable-msg=C0111
|
|
|
|
import re, time, traceback
|
|
import common
|
|
from autotest_lib.client.common_lib import global_config
|
|
|
|
RECONNECT_FOREVER = object()
|
|
|
|
_DB_EXCEPTIONS = ('DatabaseError', 'OperationalError', 'ProgrammingError')
|
|
_GLOBAL_CONFIG_NAMES = {
|
|
'username' : 'user',
|
|
'db_name' : 'database',
|
|
}
|
|
|
|
def _copy_exceptions(source, destination):
|
|
for exception_name in _DB_EXCEPTIONS:
|
|
try:
|
|
setattr(destination, exception_name,
|
|
getattr(source, exception_name))
|
|
except AttributeError:
|
|
# Under the django backend:
|
|
# Django 1.3 does not have OperationalError and ProgrammingError.
|
|
# Let's just mock these classes with the base DatabaseError.
|
|
setattr(destination, exception_name,
|
|
getattr(source, 'DatabaseError'))
|
|
|
|
|
|
class _GenericBackend(object):
|
|
def __init__(self, database_module):
|
|
self._database_module = database_module
|
|
self._connection = None
|
|
self._cursor = None
|
|
self.rowcount = None
|
|
_copy_exceptions(database_module, self)
|
|
|
|
|
|
def connect(self, host=None, username=None, password=None, db_name=None):
|
|
"""
|
|
This is assumed to enable autocommit.
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
|
|
def disconnect(self):
|
|
if self._connection:
|
|
self._connection.close()
|
|
self._connection = None
|
|
self._cursor = None
|
|
|
|
|
|
def execute(self, query, parameters=None):
|
|
if parameters is None:
|
|
parameters = ()
|
|
self._cursor.execute(query, parameters)
|
|
self.rowcount = self._cursor.rowcount
|
|
return self._cursor.fetchall()
|
|
|
|
|
|
class _MySqlBackend(_GenericBackend):
|
|
def __init__(self):
|
|
import MySQLdb
|
|
super(_MySqlBackend, self).__init__(MySQLdb)
|
|
|
|
|
|
@staticmethod
|
|
def convert_boolean(boolean, conversion_dict):
|
|
'Convert booleans to integer strings'
|
|
return str(int(boolean))
|
|
|
|
|
|
def connect(self, host=None, username=None, password=None, db_name=None):
|
|
import MySQLdb.converters
|
|
convert_dict = MySQLdb.converters.conversions
|
|
convert_dict.setdefault(bool, self.convert_boolean)
|
|
|
|
self._connection = self._database_module.connect(
|
|
host=host, user=username, passwd=password, db=db_name,
|
|
conv=convert_dict)
|
|
self._connection.autocommit(True)
|
|
self._cursor = self._connection.cursor()
|
|
|
|
|
|
class _SqliteBackend(_GenericBackend):
|
|
def __init__(self):
|
|
try:
|
|
from pysqlite2 import dbapi2
|
|
except ImportError:
|
|
from sqlite3 import dbapi2
|
|
super(_SqliteBackend, self).__init__(dbapi2)
|
|
self._last_insert_id_re = re.compile(r'\sLAST_INSERT_ID\(\)',
|
|
re.IGNORECASE)
|
|
|
|
|
|
def connect(self, host=None, username=None, password=None, db_name=None):
|
|
self._connection = self._database_module.connect(db_name)
|
|
self._connection.isolation_level = None # enable autocommit
|
|
self._cursor = self._connection.cursor()
|
|
|
|
|
|
def execute(self, query, parameters=None):
|
|
# pysqlite2 uses paramstyle=qmark
|
|
# TODO: make this more sophisticated if necessary
|
|
query = query.replace('%s', '?')
|
|
# pysqlite2 can't handle parameters=None (it throws a nonsense
|
|
# exception)
|
|
if parameters is None:
|
|
parameters = ()
|
|
# sqlite3 doesn't support MySQL's LAST_INSERT_ID(). Instead it has
|
|
# something similar called LAST_INSERT_ROWID() that will do enough of
|
|
# what we want (for our non-concurrent unittest use case).
|
|
query = self._last_insert_id_re.sub(' LAST_INSERT_ROWID()', query)
|
|
return super(_SqliteBackend, self).execute(query, parameters)
|
|
|
|
|
|
class _DjangoBackend(_GenericBackend):
|
|
def __init__(self):
|
|
from django.db import backend, connection, transaction
|
|
import django.db as django_db
|
|
super(_DjangoBackend, self).__init__(django_db)
|
|
self._django_connection = connection
|
|
self._django_transaction = transaction
|
|
|
|
|
|
def connect(self, host=None, username=None, password=None, db_name=None):
|
|
self._connection = self._django_connection
|
|
self._cursor = self._connection.cursor()
|
|
|
|
|
|
def execute(self, query, parameters=None):
|
|
try:
|
|
return super(_DjangoBackend, self).execute(query,
|
|
parameters=parameters)
|
|
finally:
|
|
self._django_transaction.commit_unless_managed()
|
|
|
|
|
|
_BACKEND_MAP = {
|
|
'mysql': _MySqlBackend,
|
|
'sqlite': _SqliteBackend,
|
|
'django': _DjangoBackend,
|
|
}
|
|
|
|
|
|
class DatabaseConnection(object):
|
|
"""
|
|
Generic wrapper for a database connection. Supports both mysql and sqlite
|
|
backends.
|
|
|
|
Public attributes:
|
|
* reconnect_enabled: if True, when an OperationalError occurs the class will
|
|
try to reconnect to the database automatically.
|
|
* reconnect_delay_sec: seconds to wait before reconnecting
|
|
* max_reconnect_attempts: maximum number of time to try reconnecting before
|
|
giving up. Setting to RECONNECT_FOREVER removes the limit.
|
|
* rowcount - will hold cursor.rowcount after each call to execute().
|
|
* global_config_section - the section in which to find DB information. this
|
|
should be passed to the constructor, not set later, and may be None, in
|
|
which case information must be passed to connect().
|
|
* debug - if set True, all queries will be printed before being executed
|
|
"""
|
|
_DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
|
|
'db_name')
|
|
|
|
def __init__(self, global_config_section=None, debug=False):
|
|
self.global_config_section = global_config_section
|
|
self._backend = None
|
|
self.rowcount = None
|
|
self.debug = debug
|
|
|
|
# reconnect defaults
|
|
self.reconnect_enabled = True
|
|
self.reconnect_delay_sec = 20
|
|
self.max_reconnect_attempts = 10
|
|
|
|
self._read_options()
|
|
|
|
|
|
def _get_option(self, name, provided_value, use_afe_setting=False):
|
|
"""Get value of given option from global config.
|
|
|
|
@param name: Name of the config.
|
|
@param provided_value: Value being provided to override the one from
|
|
global config.
|
|
@param use_afe_setting: Force to use the settings in AFE, default is
|
|
False.
|
|
"""
|
|
# TODO(dshi): This function returns the option value depends on multiple
|
|
# conditions. The value of `provided_value` has highest priority, then
|
|
# the code checks if use_afe_setting is True, if that's the case, force
|
|
# to use settings in AUTOTEST_WEB. At last the value is retrieved from
|
|
# specified global config section.
|
|
# The logic is too complicated for a generic function named like
|
|
# _get_option. Ideally we want to make it clear from caller that it
|
|
# wants to get database credential from one of the 3 ways:
|
|
# 1. Use the credential from given config section
|
|
# 2. Use the credential from AUTOTEST_WEB section
|
|
# 3. Use the credential provided by caller.
|
|
if provided_value is not None:
|
|
return provided_value
|
|
section = ('AUTOTEST_WEB' if use_afe_setting else
|
|
self.global_config_section)
|
|
if section:
|
|
global_config_name = _GLOBAL_CONFIG_NAMES.get(name, name)
|
|
return global_config.global_config.get_config_value(
|
|
section, global_config_name)
|
|
|
|
return getattr(self, name, None)
|
|
|
|
|
|
def _read_options(self, db_type=None, host=None, username=None,
|
|
password=None, db_name=None):
|
|
"""Read database information from global config.
|
|
|
|
Unless any parameter is specified a value, the connection will use
|
|
database name from given configure section (self.global_config_section),
|
|
and database credential from AFE database settings (AUTOTEST_WEB).
|
|
|
|
@param db_type: database type, default to None.
|
|
@param host: database hostname, default to None.
|
|
@param username: user name for database connection, default to None.
|
|
@param password: database password, default to None.
|
|
@param db_name: database name, default to None.
|
|
"""
|
|
self.db_name = self._get_option('db_name', db_name)
|
|
use_afe_setting = not bool(db_type or host or username or password)
|
|
|
|
# Database credential can be provided by the caller, as passed in from
|
|
# function connect.
|
|
self.db_type = self._get_option('db_type', db_type, use_afe_setting)
|
|
self.host = self._get_option('host', host, use_afe_setting)
|
|
self.username = self._get_option('username', username, use_afe_setting)
|
|
self.password = self._get_option('password', password, use_afe_setting)
|
|
|
|
|
|
def _get_backend(self, db_type):
|
|
if db_type not in _BACKEND_MAP:
|
|
raise ValueError('Invalid database type: %s, should be one of %s' %
|
|
(db_type, ', '.join(_BACKEND_MAP.keys())))
|
|
backend_class = _BACKEND_MAP[db_type]
|
|
return backend_class()
|
|
|
|
|
|
def _reached_max_attempts(self, num_attempts):
|
|
return (self.max_reconnect_attempts is not RECONNECT_FOREVER and
|
|
num_attempts > self.max_reconnect_attempts)
|
|
|
|
|
|
def _is_reconnect_enabled(self, supplied_param):
|
|
if supplied_param is not None:
|
|
return supplied_param
|
|
return self.reconnect_enabled
|
|
|
|
|
|
def _connect_backend(self, try_reconnecting=None):
|
|
num_attempts = 0
|
|
while True:
|
|
try:
|
|
self._backend.connect(host=self.host, username=self.username,
|
|
password=self.password,
|
|
db_name=self.db_name)
|
|
return
|
|
except self._backend.OperationalError:
|
|
num_attempts += 1
|
|
if not self._is_reconnect_enabled(try_reconnecting):
|
|
raise
|
|
if self._reached_max_attempts(num_attempts):
|
|
raise
|
|
traceback.print_exc()
|
|
print ("Can't connect to database; reconnecting in %s sec" %
|
|
self.reconnect_delay_sec)
|
|
time.sleep(self.reconnect_delay_sec)
|
|
self.disconnect()
|
|
|
|
|
|
def connect(self, db_type=None, host=None, username=None, password=None,
|
|
db_name=None, try_reconnecting=None):
|
|
"""
|
|
Parameters passed to this function will override defaults from global
|
|
config. try_reconnecting, if passed, will override
|
|
self.reconnect_enabled.
|
|
"""
|
|
self.disconnect()
|
|
self._read_options(db_type, host, username, password, db_name)
|
|
|
|
self._backend = self._get_backend(self.db_type)
|
|
_copy_exceptions(self._backend, self)
|
|
self._connect_backend(try_reconnecting)
|
|
|
|
|
|
def disconnect(self):
|
|
if self._backend:
|
|
self._backend.disconnect()
|
|
|
|
|
|
def execute(self, query, parameters=None, try_reconnecting=None):
|
|
"""
|
|
Execute a query and return cursor.fetchall(). try_reconnecting, if
|
|
passed, will override self.reconnect_enabled.
|
|
"""
|
|
if self.debug:
|
|
print 'Executing %s, %s' % (query, parameters)
|
|
# _connect_backend() contains a retry loop, so don't loop here
|
|
try:
|
|
results = self._backend.execute(query, parameters)
|
|
except self._backend.OperationalError:
|
|
if not self._is_reconnect_enabled(try_reconnecting):
|
|
raise
|
|
traceback.print_exc()
|
|
print ("MYSQL connection died; reconnecting")
|
|
self.disconnect()
|
|
self._connect_backend(try_reconnecting)
|
|
results = self._backend.execute(query, parameters)
|
|
|
|
self.rowcount = self._backend.rowcount
|
|
return results
|
|
|
|
|
|
def get_database_info(self):
|
|
return dict((attribute, getattr(self, attribute))
|
|
for attribute in self._DATABASE_ATTRIBUTES)
|
|
|
|
|
|
@classmethod
|
|
def get_test_database(cls, file_path=':memory:', **constructor_kwargs):
|
|
"""
|
|
Factory method returning a DatabaseConnection for a temporary in-memory
|
|
database.
|
|
"""
|
|
database = cls(**constructor_kwargs)
|
|
database.reconnect_enabled = False
|
|
database.connect(db_type='sqlite', db_name=file_path)
|
|
return database
|
|
|
|
|
|
class TranslatingDatabase(DatabaseConnection):
|
|
"""
|
|
Database wrapper than applies arbitrary substitution regexps to each query
|
|
string. Useful for SQLite testing.
|
|
"""
|
|
def __init__(self, translators):
|
|
"""
|
|
@param translation_regexps: list of callables to apply to each query
|
|
string (in order). Each accepts a query string and returns a
|
|
(possibly) modified query string.
|
|
"""
|
|
super(TranslatingDatabase, self).__init__()
|
|
self._translators = translators
|
|
|
|
|
|
def execute(self, query, parameters=None, try_reconnecting=None):
|
|
for translator in self._translators:
|
|
query = translator(query)
|
|
return super(TranslatingDatabase, self).execute(
|
|
query, parameters=parameters, try_reconnecting=try_reconnecting)
|
|
|
|
|
|
@classmethod
|
|
def make_regexp_translator(cls, search_re, replace_str):
|
|
"""
|
|
Returns a translator that calls re.sub() on the query with the given
|
|
search and replace arguments.
|
|
"""
|
|
def translator(query):
|
|
return re.sub(search_re, replace_str, query)
|
|
return translator
|