You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
484 lines
16 KiB
484 lines
16 KiB
#!/usr/bin/python2 -u
|
|
|
|
import os, sys, re, tempfile
|
|
from optparse import OptionParser
|
|
import common
|
|
from autotest_lib.client.common_lib import utils
|
|
from autotest_lib.database import database_connection
|
|
|
|
# Name of the table that stores the database's schema version.
MIGRATE_TABLE = 'migrate_info'

# Parent of the directory that contains this file.
_AUTODIR = os.path.join(os.path.dirname(__file__), '..')

# Migration directories, keyed by the connection's global config section.
_MIGRATIONS_DIRS = dict(
    AUTOTEST_WEB=os.path.join(_AUTODIR, 'frontend', 'migrations'),
    TKO=os.path.join(_AUTODIR, 'tko', 'migrations'),
    AUTOTEST_SERVER_DB=os.path.join(_AUTODIR, 'database',
                                    'server_db_migrations'),
)

# Fallback for config sections not listed above. The path is relative, so it
# resolves against the current working directory.
_DEFAULT_MIGRATIONS_DIR = 'migrations'
|
|
|
|
class Migration(object):
    """A single database migration backed by a Python module.

    For each direction the module has to expose at least one of two
    attributes: a function taking the manager (migrate_up / migrate_down) or
    a SQL script string (UP_SQL / DOWN_SQL). When both are present the
    function wins.
    """
    _UP_ATTRIBUTES = ('migrate_up', 'UP_SQL')
    _DOWN_ATTRIBUTES = ('migrate_down', 'DOWN_SQL')

    def __init__(self, name, version, module):
        """Stores the migration's identity and validates its module.

        @param name: The migration (module) name.
        @param version: The integer version this migration leads to.
        @param module: The imported migration module.

        """
        self.name = name
        self.version = version
        self.module = module
        for attributes in (self._UP_ATTRIBUTES, self._DOWN_ATTRIBUTES):
            self._check_attributes(attributes)


    @classmethod
    def from_file(cls, filename):
        """Builds a Migration out of a migration file name.

        The first three characters of the name are parsed as the version and
        the name minus its last three characters is imported as the module.

        @param filename: Name of a migration file.

        @return An instantiated Migration object.

        """
        module_name = filename[:-3]
        version = int(filename[:3])
        module = __import__(module_name, globals(), locals(), [])
        return cls(module_name, version, module)


    def _check_attributes(self, attributes):
        """Asserts that the module has the method or the SQL attribute.

        @param attributes: A (method name, SQL attribute name) pair.

        """
        method_name, sql_name = attributes
        has_method = hasattr(self.module, method_name)
        assert has_method or hasattr(self.module, sql_name)


    def _execute_migration(self, attributes, manager):
        """Runs one direction of the migration.

        @param attributes: A (method name, SQL attribute name) pair.
        @param manager: A MigrationManager object.

        """
        method_name, sql_name = attributes
        handler = getattr(self.module, method_name, None)
        if not handler:
            # No (truthy) function available: fall back to the SQL script.
            script = getattr(self.module, sql_name)
            assert isinstance(script, basestring)
            manager.execute_script(script)
            return
        assert callable(handler)
        handler(manager)


    def migrate_up(self, manager):
        """Performs an up migration (to a newer version).

        @param manager: A MigrationManager object.

        """
        self._execute_migration(self._UP_ATTRIBUTES, manager)


    def migrate_down(self, manager):
        """Performs a down migration (to an older version).

        @param manager: A MigrationManager object.

        """
        self._execute_migration(self._DOWN_ATTRIBUTES, manager)
|
|
|
|
|
|
class MigrationManager(object):
    """Manages database migrations.

    Wraps a database connection object and applies the numbered migration
    modules found in a migrations directory, tracking the current schema
    version in the MIGRATE_TABLE table.
    """
    connection = None
    cursor = None
    migrations_dir = None

    def __init__(self, database_connection, migrations_dir=None, force=False):
        """Creates a manager bound to a database connection.

        @param database_connection: The connection object to operate on. This
                class uses its execute(), connect(), disconnect(),
                get_database_info(), rowcount, DatabaseError and
                global_config_section members.
        @param migrations_dir: Directory holding the migration modules, or
                None to pick it from the connection's config section.
        @param force: When True, confirm_initialization() does not prompt.

        """
        self._database = database_connection
        self.force = force
        # A boolean, this will only be set to True if this migration should be
        # simulated rather than actually taken.  For use with migrations that
        # may make destructive queries
        self.simulate = False
        self._set_migrations_dir(migrations_dir)


    def _set_migrations_dir(self, migrations_dir=None):
        """Resolves the migrations directory and makes it importable.

        @param migrations_dir: Explicit directory, or None to look it up in
                _MIGRATIONS_DIRS by config section (falling back to
                _DEFAULT_MIGRATIONS_DIR).

        """
        config_section = self._config_section()
        if migrations_dir is None:
            migrations_dir = os.path.abspath(
                _MIGRATIONS_DIRS.get(config_section, _DEFAULT_MIGRATIONS_DIR))
        self.migrations_dir = migrations_dir
        # Migration.from_file() imports migrations by bare module name, so
        # the directory has to be on sys.path.
        sys.path.append(migrations_dir)
        assert os.path.exists(migrations_dir), migrations_dir + " doesn't exist"


    def _config_section(self):
        """Returns the connection's global config section name."""
        return self._database.global_config_section


    def get_db_name(self):
        """Gets the database name."""
        return self._database.get_database_info()['db_name']


    def execute(self, query, *parameters):
        """Executes a database query.

        @param query: The query to execute.
        @param parameters: Associated parameters for the query.

        @return The result of the query.

        """
        return self._database.execute(query, parameters)


    def execute_script(self, script):
        """Executes a set of database queries.

        The script is split on every ';' character, so a semicolon inside a
        string literal would also split the statement. Empty statements are
        skipped.

        @param script: A string of semicolon-separated queries.

        """
        sql_statements = [statement.strip()
                          for statement in script.split(';')
                          if statement.strip()]
        for statement in sql_statements:
            self.execute(statement)


    def check_migrate_table_exists(self):
        """Checks whether the migration table exists.

        @return True if selecting from the table succeeds, False if it raises
                the connection's DatabaseError.

        """
        try:
            self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
            return True
        except self._database.DatabaseError:
            # we can't check for more specifics due to differences between DB
            # backends (we can't even check for a subclass of DatabaseError)
            return False


    def create_migrate_table(self):
        """Creates the migration table, or empties it, and sets version 0."""
        if not self.check_migrate_table_exists():
            self.execute("CREATE TABLE %s (`version` integer)" %
                         MIGRATE_TABLE)
        else:
            self.execute("DELETE FROM %s" % MIGRATE_TABLE)
        self.execute("INSERT INTO %s VALUES (0)" % MIGRATE_TABLE)
        assert self._database.rowcount == 1


    def set_db_version(self, version):
        """Sets the database version.

        @param version: The version to which to set the database.

        """
        assert isinstance(version, int)
        self.execute("UPDATE %s SET version=%%s" % MIGRATE_TABLE,
                     version)
        assert self._database.rowcount == 1


    def get_db_version(self):
        """Gets the database version.

        @return The database version, or 0 when the migration table is
                missing or empty.

        """
        if not self.check_migrate_table_exists():
            return 0
        rows = self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
        if len(rows) == 0:
            return 0
        assert len(rows) == 1 and len(rows[0]) == 1
        return rows[0][0]


    def get_migrations(self, minimum_version=None, maximum_version=None):
        """Gets the list of migrations to perform.

        Only files named like NNN_<anything>.py in the migrations directory
        are considered; every one of them is imported before filtering.

        @param minimum_version: The minimum database version (inclusive).
        @param maximum_version: The maximum database version (inclusive).

        @return A list of Migration objects, sorted by file name.

        """
        migrate_files = [filename for filename
                         in os.listdir(self.migrations_dir)
                         if re.match(r'^\d\d\d_.*\.py$', filename)]
        migrate_files.sort()
        migrations = [Migration.from_file(filename)
                      for filename in migrate_files]
        if minimum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version >= minimum_version]
        if maximum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version <= maximum_version]
        return migrations


    def do_migration(self, migration, migrate_up=True):
        """Performs a migration.

        @param migration: The Migration to perform.
        @param migrate_up: Whether to migrate up (if not, then migrates down).

        """
        direction = 'up' if migrate_up else 'down'
        print('Applying migration %s %s' % (migration.name, direction))
        if migrate_up:
            # An up migration must start from the version just below it.
            assert self.get_db_version() == migration.version - 1
            migration.migrate_up(self)
            new_version = migration.version
        else:
            assert self.get_db_version() == migration.version
            migration.migrate_down(self)
            new_version = migration.version - 1
        self.set_db_version(new_version)


    def migrate_to_version(self, version):
        """Performs a migration to a specified version.

        @param version: The version to which to migrate the database.

        """
        current_version = self.get_db_version()
        if current_version == 0 and self._config_section() == 'AUTOTEST_WEB':
            self._migrate_from_base()
            current_version = self.get_db_version()

        if current_version < version:
            lower, upper = current_version, version
            migrate_up = True
        else:
            lower, upper = version, current_version
            migrate_up = False

        migrations = self.get_migrations(lower + 1, upper)
        if not migrate_up:
            # Down migrations are undone newest first.
            migrations.reverse()
        for migration in migrations:
            self.do_migration(migration, migrate_up)

        assert self.get_db_version() == version
        print('At version %s' % (version,))


    def _migrate_from_base(self):
        """Initialize the AFE database.

        After confirmation, executes schema_129.sql (located next to this
        file) with the database username substituted in, then creates the
        migration table and records version 129.
        """
        self.confirm_initialization()

        migration_script = utils.read_file(
                os.path.join(os.path.dirname(__file__), 'schema_129.sql'))
        migration_script = migration_script % (
                dict(username=self._database.get_database_info()['username']))
        self.execute_script(migration_script)

        self.create_migrate_table()
        self.set_db_version(129)


    def confirm_initialization(self):
        """Confirms with the user that we should initialize the database.

        Does nothing when the manager was created with force=True.

        @raises Exception, if the user chooses to abort the migration.

        """
        if not self.force:
            response = raw_input(
                'Your %s database does not appear to be initialized.  Do you '
                'want to recreate it (this will result in loss of any existing '
                'data) (yes/No)? ' % self.get_db_name())
            if response != 'yes':
                raise Exception('User has chosen to abort migration')


    def get_latest_version(self):
        """Gets the latest database version.

        Raises IndexError when the migrations directory has no migrations.
        """
        migrations = self.get_migrations()
        return migrations[-1].version


    def migrate_to_latest(self):
        """Migrates the database to the latest version."""
        latest_version = self.get_latest_version()
        self.migrate_to_version(latest_version)


    def initialize_test_db(self):
        """Initializes a test database.

        Creates a database named 'test_' + the current database name and
        leaves the connection pointed at it.
        """
        db_name = self.get_db_name()
        test_db_name = 'test_' + db_name
        # first, connect to no DB so we can create a test DB
        self._database.connect(db_name='')
        print('Creating test DB %s' % (test_db_name,))
        self.execute('CREATE DATABASE ' + test_db_name)
        self._database.disconnect()
        # now connect to the test DB
        self._database.connect(db_name=test_db_name)


    def remove_test_db(self):
        """Removes a test database.

        Drops the database the connection currently reports as its db_name,
        then reconnects with the connection's default settings.
        """
        print('Removing test DB')
        self.execute('DROP DATABASE ' + self.get_db_name())
        # reset connection back to real DB
        self._database.disconnect()
        self._database.connect()


    def get_mysql_args(self):
        """Returns the mysql arguments as a string.

        NOTE(review): the password is embedded unquoted in the argument
        string, which callers pass to os.system().
        """
        return ('-u %(username)s -p%(password)s -h %(host)s %(db_name)s' %
                self._database.get_database_info())


    def migrate_to_version_or_latest(self, version):
        """Migrates to either a specified version, or the latest version.

        @param version: The version to which to migrate the database,
                        or None in order to migrate to the latest version.

        """
        if version is None:
            self.migrate_to_latest()
        else:
            self.migrate_to_version(version)


    def do_sync_db(self, version=None):
        """Migrates the database.

        @param version: The version to which to migrate the database.

        """
        print('Migration starting for database %s' % (self.get_db_name(),))
        self.migrate_to_version_or_latest(version)
        print('Migration complete')


    def test_sync_db(self, version=None):
        """Create a fresh database and run all migrations on it.

        The test database is removed even when the migration fails; the
        success message is only printed when nothing raised.

        @param version: The version to which to migrate the database.

        """
        self.initialize_test_db()
        try:
            print('Starting migration test on DB %s' % (self.get_db_name(),))
            self.migrate_to_version_or_latest(version)
            # show schema to the user
            os.system('mysqldump %s --no-data=true '
                      '--add-drop-table=false' %
                      self.get_mysql_args())
        finally:
            self.remove_test_db()
        print('Test finished successfully')


    def simulate_sync_db(self, version=None):
        """Creates a fresh DB, copies existing DB to it, then synchronizes it.

        @param version: The version to which to migrate the database.

        """
        db_version = self.get_db_version()
        # don't do anything if we're already at the latest version
        if db_version == self.get_latest_version():
            print('Skipping simulation, already at latest version')
            return
        # get existing data
        self.initialize_and_fill_test_db()
        try:
            print('Starting migration test on DB %s' % (self.get_db_name(),))
            self.migrate_to_version_or_latest(version)
        finally:
            self.remove_test_db()
        print('Test finished successfully')


    def initialize_and_fill_test_db(self):
        """Initializes and fills up a test database.

        Dumps the current database to a temporary file with mysqldump,
        creates the test database, and loads the dump into it with mysql.
        The temporary file is always removed, even if a step raises.
        """
        print('Dumping existing data')
        dump_fd, dump_file = tempfile.mkstemp('.migrate_dump')
        # The shell commands below address the dump by file name, so our own
        # descriptor is not needed; close it right away so it cannot leak.
        os.close(dump_fd)
        try:
            os.system('mysqldump %s >%s' %
                      (self.get_mysql_args(), dump_file))
            # fill in test DB
            self.initialize_test_db()
            print('Filling in test DB')
            os.system('mysql %s <%s' % (self.get_mysql_args(), dump_file))
        finally:
            os.remove(dump_file)
|
|
|
|
|
|
# Help text printed by main() when no command or an unknown command is given.
# The leading %s is filled with the script name from sys.argv[0].
USAGE = (
    "%s [options] sync|test|simulate|safesync [version]\n"
    "Options:\n"
    "-d --database Which database to act on\n"
    "-f --force Don't ask for confirmation\n"
    "--debug Print all DB queries"
) % sys.argv[0]
|
|
|
|
|
|
def main():
    """Main function for the migration script.

    Parses the command line and runs the requested command (sync, test,
    simulate or safesync), optionally against an explicit target version.
    Prints USAGE when no command or an unknown command is given. The command
    line is validated before the database connection is opened, so printing
    the usage text does not need a reachable database.
    """
    parser = OptionParser()
    parser.add_option("-d", "--database",
                      help="which database to act on",
                      dest="database",
                      default="AUTOTEST_WEB")
    parser.add_option("-f", "--force", help="don't ask for confirmation",
                      action="store_true")
    parser.add_option('--debug', help='print all DB queries',
                      action='store_true')
    (options, args) = parser.parse_args()

    if not args or args[0] not in ('sync', 'test', 'simulate', 'safesync'):
        print(USAGE)
        return
    command = args[0]

    if len(args) > 1:
        try:
            version = int(args[1])
        except ValueError:
            # parser.error() reports the problem on stderr and exits.
            parser.error('version must be an integer, got %r' % (args[1],))
    else:
        version = None

    manager = get_migration_manager(db_name=options.database,
                                    debug=options.debug, force=options.force)

    if command == 'sync':
        manager.do_sync_db(version)
    elif command == 'test':
        manager.simulate = True
        manager.test_sync_db(version)
    elif command == 'simulate':
        manager.simulate = True
        manager.simulate_sync_db(version)
    else:
        # safesync: rehearse on a copy first, then migrate the real database.
        print('Simluating migration')
        manager.simulate = True
        manager.simulate_sync_db(version)
        print('Performing real migration')
        manager.simulate = False
        manager.do_sync_db(version)
|
|
|
|
|
|
def get_migration_manager(db_name, debug, force):
    """Builds a MigrationManager on top of a freshly opened connection.

    The connection is created with reconnecting disabled and is already
    connected when the manager is returned.

    @param db_name: The database name, handed to DatabaseConnection.
    @param debug: Whether to print debug messages.
    @param force: Whether to force migration without asking for confirmation.

    @return A created MigrationManager object.

    """
    connection = database_connection.DatabaseConnection(db_name)
    connection.debug = debug
    connection.reconnect_enabled = False
    connection.connect()
    manager = MigrationManager(connection, force=force)
    return manager
|
|
|
|
|
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
|