You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
164 lines
4.5 KiB
164 lines
4.5 KiB
#!/usr/bin/python2
|
|
|
|
import unittest, tempfile, os
|
|
import common
|
|
import MySQLdb
|
|
from autotest_lib.client.common_lib import global_config
|
|
from autotest_lib.database import database_connection, migrate
|
|
|
|
# Name of the global config section that describes the database.  That
# database itself is never used here; the tests use the corresponding test
# database instead (test_<db name>).
CONFIG_DB = 'AUTOTEST_WEB'

# How many DummyMigration instances are built into MIGRATIONS below.
NUM_MIGRATIONS = 3
|
|
|
|
class DummyMigration(object):
    """Fake migration that logs every step it performs.

    Each instance stands for one version.  Running it up or down appends a
    ``(version, direction)`` tuple to the class-level ``migrations_done``
    list, which is shared by all instances so a test can inspect the order
    in which migrations ran.
    """

    # Shared log of (version, direction) tuples, in execution order.
    migrations_done = []

    def __init__(self, version):
        """Store the version and derive a zero-padded name from it."""
        self.version = version
        self.name = '%03d_test' % version

    @classmethod
    def get_migrations_done(cls):
        """Return the list of (version, direction) records logged so far."""
        return cls.migrations_done

    @classmethod
    def clear_migrations_done(cls):
        """Rebind the class-level log to a new empty list."""
        cls.migrations_done = []

    @classmethod
    def do_migration(cls, version, direction):
        """Append one (version, direction) record to the shared log."""
        record = (version, direction)
        cls.migrations_done.append(record)

    def migrate_up(self, manager):
        """Log an 'up' step; version 1 also asks the manager to create the
        migrate table."""
        self.do_migration(self.version, 'up')
        if self.version != 1:
            return
        manager.create_migrate_table()

    def migrate_down(self, manager):
        """Log a 'down' step.  ``manager`` is accepted but not used."""
        self.do_migration(self.version, 'down')
|
|
|
|
|
|
# One DummyMigration per version, numbered 1..NUM_MIGRATIONS inclusive, so
# MIGRATIONS[v - 1] is the migration for version v.  range() is used rather
# than xrange() so this line runs on both Python 2 and Python 3.
MIGRATIONS = [DummyMigration(n) for n in range(1, NUM_MIGRATIONS + 1)]
|
|
|
|
|
|
class TestableMigrationManager(migrate.MigrationManager):
    """MigrationManager variant that serves the in-memory MIGRATIONS list."""

    def _set_migrations_dir(self, migrations_dir=None):
        """Deliberately do nothing.

        NOTE(review): presumably the base class uses this hook to locate
        migration files on disk, which these tests do not need -- confirm
        against migrate.MigrationManager.
        """
        pass

    def get_migrations(self, minimum_version=None, maximum_version=None):
        """Return the dummy migrations whose version lies in the given range.

        Versions are 1-based and MIGRATIONS[v - 1] holds version v.

        @param minimum_version: lowest version to include; None means start
                from the first migration.
        @param maximum_version: highest version to include; None means go up
                to the last migration.
        @return: list of DummyMigration objects in ascending version order;
                empty when the range selects nothing.
        """
        # Compare against None explicitly: a bound of 0 is a real value
        # ("nothing at or below version 0"), not a request for the default.
        if minimum_version is None:
            first_version = 1
        else:
            # Clamp so the slice start never goes negative and wraps around.
            first_version = max(minimum_version, 1)

        if maximum_version is None:
            last_version = len(MIGRATIONS)
        else:
            # Clamp so a negative bound yields an empty result instead of
            # being interpreted as an index from the end of the list.
            last_version = max(maximum_version, 0)

        return MIGRATIONS[first_version - 1:last_version]
|
|
|
|
|
|
class MigrateManagerTest(unittest.TestCase):
    """Exercise do_sync_db() on TestableMigrationManager with a test database.

    The DummyMigration log is cleared in setUp, so every test starts with an
    empty record of executed migrations.
    """

    def setUp(self):
        self._database = (
            database_connection.DatabaseConnection.get_test_database())
        self._database.connect()
        self.manager = TestableMigrationManager(self._database)
        DummyMigration.clear_migrations_done()

    def tearDown(self):
        self._database.disconnect()

    def test_sync(self):
        # Syncing with no argument must run every migration upwards, in order.
        self.manager.do_sync_db()
        self.assertEqual(self.manager.get_db_version(), NUM_MIGRATIONS)
        self.assertEqual(DummyMigration.get_migrations_done(),
                         [(1, 'up'), (2, 'up'), (3, 'up')])

        # Syncing back to version 0 must undo them in reverse order.
        DummyMigration.clear_migrations_done()
        self.manager.do_sync_db(0)
        self.assertEqual(self.manager.get_db_version(), 0)
        self.assertEqual(DummyMigration.get_migrations_done(),
                         [(3, 'down'), (2, 'down'), (1, 'down')])

    def test_sync_one_by_one(self):
        # Step up one version at a time; the latest record must be that
        # version going up.
        for version in range(1, NUM_MIGRATIONS + 1):
            self.manager.do_sync_db(version)
            self.assertEqual(self.manager.get_db_version(), version)
            self.assertEqual(DummyMigration.get_migrations_done()[-1],
                             (version, 'up'))

        # Step back down to 0; reaching `version` means the migration for
        # version + 1 was the one taken down.
        for version in range(NUM_MIGRATIONS - 1, -1, -1):
            self.manager.do_sync_db(version)
            self.assertEqual(self.manager.get_db_version(), version)
            self.assertEqual(DummyMigration.get_migrations_done()[-1],
                             (version + 1, 'down'))

    def test_null_sync(self):
        # A second sync right after a full sync must not run any migration.
        self.manager.do_sync_db()
        DummyMigration.clear_migrations_done()
        self.manager.do_sync_db()
        self.assertEqual(DummyMigration.get_migrations_done(), [])
|
|
|
|
|
|
class DummyMigrationManager(object):
    """Stand-in manager that only remembers the scripts it was asked to run."""

    def __init__(self):
        # Scripts passed to execute_script(), in call order.
        self.calls = []

    def execute_script(self, script):
        """Record ``script`` instead of executing it."""
        self.calls.append(script)
|
|
|
|
|
|
class MigrationTest(unittest.TestCase):
    """Check that migrate.Migration drives both styles of migration module.

    A migration module may either provide migrate_up/migrate_down callables
    or UP_SQL/DOWN_SQL strings; in both cases the manager is expected to end
    up with the scripts 'foo' and 'bar', in that order.
    """

    def setUp(self):
        self.manager = DummyMigrationManager()

    def _do_migration(self, migration_module):
        # Run the wrapped module up and then down against the recording
        # manager, then verify which scripts it was asked to execute.
        migration = migrate.Migration('name', 1, migration_module)
        migration.migrate_up(self.manager)
        migration.migrate_down(self.manager)

        self.assertEqual(self.manager.calls, ['foo', 'bar'])

    def test_migration_with_methods(self):
        # Module style 1: explicit migrate_up / migrate_down callables.
        class DummyMigration(object):
            @staticmethod
            def migrate_up(manager):
                manager.execute_script('foo')

            @staticmethod
            def migrate_down(manager):
                manager.execute_script('bar')

        self._do_migration(DummyMigration)

    def test_migration_with_strings(self):
        # Module style 2: UP_SQL / DOWN_SQL string attributes.
        class DummyMigration(object):
            UP_SQL = 'foo'
            DOWN_SQL = 'bar'

        self._do_migration(DummyMigration)
|
|
|
|
|
|
# Run every test case in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()
|