"""Migration router.""" from __future__ import annotations import os import pkgutil import re import sys from functools import cached_property from importlib import import_module from pathlib import Path from types import ModuleType from typing import TYPE_CHECKING, Any, Final, Iterable, List, Optional, Set, Type, Union from unittest import mock import peewee as pw from .auto import NEWLINE, diff_many from .logs import logger from .migrator import Migrator from .models import MIGRATE_TABLE, MigrateHistory from .template import TEMPLATE if TYPE_CHECKING: from logging import Logger CLEAN_RE: Final = re.compile(r"\s+$", re.M) CURDIR: Final = Path.cwd() DEFAULT_MIGRATE_DIR: Final = CURDIR / "migrations" def void(m, d, fake=None): return None class BaseRouter(object): """Abstract base class for router.""" def __init__( # noqa: self, database: Union[pw.Database, pw.Proxy], migrate_table=MIGRATE_TABLE, ignore: Optional[Iterable[str]] = None, schema: Optional[str] = None, logger: Logger = logger, ): """Initialize the router.""" self.database = database self.migrate_table = migrate_table self.schema = schema self.ignore = ignore self.logger = logger if not isinstance(self.database, (pw.Database, pw.Proxy)): raise TypeError("Invalid database: %s" % database) @cached_property def model(self) -> Type[MigrateHistory]: """Initialize and cache MigrationHistory model.""" meta = MigrateHistory._meta # type: ignore[] meta.database = self.database meta.table_name = self.migrate_table meta.schema = self.schema MigrateHistory.create_table(safe=True) return MigrateHistory @property def todo(self) -> Iterable[str]: """Get migrations to run.""" raise NotImplementedError @property def done(self) -> List[str]: """Scan migrations in database.""" return [mm.name for mm in self.model.select().order_by(self.model.id)] @property def diff(self) -> List[str]: """Calculate difference between fs and db.""" done = set(self.done) return [name for name in self.todo if name not in done] @cached_property def migrator(self) -> Migrator: """Create migrator and setup it with fake migrations.""" migrator = Migrator(self.database) for name in self.done: self.run_one(name, migrator) return migrator def create(self, name: str = "auto", *, auto: Any = False) -> Optional[str]: """Create a migration. :param auto: Python module path to scan for models. """ migrate = rollback = "" if auto: # Need to append the CURDIR to the path for import to work. sys.path.append(f"{ CURDIR }") models = auto if isinstance(auto, list) else [auto] if not all(_check_model(m) for m in models): try: modules = models if isinstance(auto, bool): modules = [ m for _, m, ispkg in pkgutil.iter_modules([f"{CURDIR}"]) if ispkg ] models = [m for module in modules for m in load_models(module)] except ImportError: self.logger.exception("Can't import models module: %s", auto) return None if self.ignore: models = [m for m in models if m._meta.name not in self.ignore] # type: ignore[] for migration in self.diff: self.run_one(migration, self.migrator, fake=True) migrate = compile_migrations(self.migrator, models) if not migrate: self.logger.warning("No changes found.") return None rollback = compile_migrations(self.migrator, models, reverse=True) self.logger.info('Creating migration "%s"', name) name = self.compile(name, migrate, rollback) self.logger.info('Migration has been created as "%s"', name) return name def merge(self, name: str = "initial"): """Merge migrations into one.""" migrator = Migrator(self.database) migrate = compile_migrations(migrator, list(self.migrator.orm)) if not migrate: return self.logger.error("Can't merge migrations") self.clear() self.logger.info('Merge migrations into "%s"', name) rollback = compile_migrations(self.migrator, []) name = self.compile(name, migrate, rollback, 0) migrator = Migrator(self.database) self.run_one(name, migrator, fake=True, force=True) self.logger.info('Migrations has been merged into "%s"', name) return None def clear(self): """Clear migrations.""" self.model.delete().execute() def compile( # noqa: self, name: str, migrate: str = "", rollback: str = "", num: Optional[int] = None, ) -> str: """Create a migration.""" raise NotImplementedError def read(self, name: str): """Read migration from file.""" raise NotImplementedError def run_one( self, name: str, migrator: Migrator, *, fake: bool = True, downgrade: bool = False, force: bool = False, ) -> str: """Run/emulate a migration with given name.""" try: migrate, rollback = self.read(name) if fake: mocked_cursor = mock.Mock() mocked_cursor.fetch_one.return_value = None with mock.patch("peewee.Model.select"), mock.patch( "peewee.Database.execute_sql", return_value=mocked_cursor ): migrate(migrator, self.database, fake=fake) if force: self.model.create(name=name) self.logger.info("Done %s", name) migrator.__ops__ = [] return name with self.database.transaction(): if not downgrade: self.logger.info('Migrate "%s"', name) migrate(migrator, self.database, fake=fake) migrator() self.model.create(name=name) else: self.logger.info("Rolling back %s", name) rollback(migrator, self.database, fake=fake) migrator() self.model.delete().where(self.model.name == name).execute() self.logger.info("Done %s", name) return name except Exception: self.database.rollback() operation = "Migration" if not downgrade else "Rollback" self.logger.exception("%s failed: %s", operation, name) raise def run(self, name: Optional[str] = None, *, fake: bool = False) -> List[str]: """Run migrations.""" self.logger.info("Starting migrations") done: List[str] = [] diff = self.diff if not diff: self.logger.info("There is nothing to migrate") return done migrator = self.migrator for mname in diff: done.append(self.run_one(mname, migrator, fake=fake, force=fake)) if name and name == mname: break return done def rollback(self): """Rollback the latest migration.""" done = self.done if not done: msg = "There is nothing to rollback" raise RuntimeError(msg) name = done[-1] migrator = self.migrator self.run_one(name, migrator, fake=False, downgrade=True) self.logger.warning("Downgraded migration: %s", name) class Router(BaseRouter): """File system router.""" filemask = re.compile(r"[\d]{3}_[^\.]+\.py$") def __init__( self, database, migrate_dir: Optional[Union[str, Path]] = None, **kwargs, ): """Initialize the router.""" super(Router, self).__init__(database, **kwargs) if migrate_dir is None: migrate_dir = DEFAULT_MIGRATE_DIR elif isinstance(migrate_dir, str): migrate_dir = Path(migrate_dir) self.migrate_dir = migrate_dir @property def todo(self): """Scan migrations in file system.""" if not self.migrate_dir.exists(): self.logger.warning("Migration directory: %s does not exist.", self.migrate_dir) self.migrate_dir.mkdir(parents=True) return sorted(f[:-3] for f in os.listdir(self.migrate_dir) if self.filemask.match(f)) def compile(self, name, migrate="", rollback="", num=None) -> str: # noqa: """Create a migration.""" if num is None: num = len(self.todo) name = "{:03}_".format(num + 1) + name filename = name + ".py" path = self.migrate_dir / filename with path.open("w") as f: f.write(TEMPLATE.format(migrate=migrate, rollback=rollback, name=filename)) return name def read(self, name): """Read migration from file.""" path = self.migrate_dir / (name + ".py") with path.open("r") as f: code = f.read() scope = {} code = compile(code, "", "exec", dont_inherit=True) exec(code, scope, None) return scope.get("migrate", void), scope.get("rollback", void) def clear(self): """Remove migrations from fs.""" super(Router, self).clear() for name in self.todo: path = self.migrate_dir / (name + ".py") path.unlink() class ModuleRouter(BaseRouter): """Module based router.""" def __init__(self, database, migrate_module="migrations", **kwargs): """Initialize the router.""" super(ModuleRouter, self).__init__(database, **kwargs) if isinstance(migrate_module, str): migrate_module = import_module(migrate_module) self.migrate_module = migrate_module def read(self, name): """Read migrations from a module.""" mod = getattr(self.migrate_module, name) return getattr(mod, "migrate", void), getattr(mod, "rollback", void) def load_models(module: Union[str, ModuleType]) -> Set[Type[pw.Model]]: """Load models from given module.""" modules = [module] if isinstance(module, ModuleType) else _import_submodules(module) return { m for module in modules for m in filter(_check_model, (getattr(module, name) for name in dir(module))) } def _import_submodules(package, passed=...): if passed is ...: passed = set() if isinstance(package, str): package = import_module(package) # https://github.com/klen/peewee_migrate/issues/125 if not hasattr(package, "__path__"): return {package} modules = [] if set(package.__path__) & passed: return modules passed |= set(package.__path__) for loader, name, is_pkg in pkgutil.walk_packages(package.__path__, package.__name__ + "."): spec = loader.find_spec(name, None) if spec is None or spec.loader is None: continue module = spec.loader.load_module(name) modules.append(module) if is_pkg: modules += _import_submodules(module) return modules def _check_model(obj): """Check object if it's a peewee model and unique.""" return isinstance(obj, type) and issubclass(obj, pw.Model) and hasattr(obj, "_meta") def compile_migrations(migrator: Migrator, models, *, reverse=False): """Compile migrations for given models.""" source = list(migrator.orm) if reverse: source, models = models, source migrations = diff_many(models, source, migrator, reverse=reverse) if not migrations: return "" code = NEWLINE + NEWLINE.join("\n\n".join(migrations).split("\n")) return CLEAN_RE.sub("\n", code)