"""Run migrations."""

from __future__ import annotations

from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    List,
    Union,
    cast,
    overload,
)

import peewee as pw
from playhouse.migrate import (
    SQL,
    Context,
    MySQLDatabase,
    Operation,
    PostgresqlDatabase,
    SqliteDatabase,
    make_index_name,
    operation,
)
from playhouse.migrate import MySQLMigrator as MqM
from playhouse.migrate import PostgresqlMigrator as PgM
from playhouse.migrate import SchemaMigrator as ScM
from playhouse.migrate import SqliteMigrator as SqM

from .logs import logger

if TYPE_CHECKING:
    from .types import TModelType, TVModelType


class SchemaMigrator(ScM):
    """Implement migrations."""

    @classmethod
    def from_database(cls, database: Union[pw.Database, pw.Proxy]) -> SchemaMigrator:
        """Initialize migrator by db.

        A ``pw.Proxy`` is unwrapped to the database it points at before the
        backend is detected, mirroring what ``Migrator.__init__`` does.

        Raises:
            ValueError: if the database is not PostgreSQL, SQLite or MySQL.
        """
        if isinstance(database, pw.Proxy):
            database = database.obj

        if isinstance(database, PostgresqlDatabase):
            return PostgresqlMigrator(database)

        if isinstance(database, SqliteDatabase):
            return SqliteMigrator(database)

        if isinstance(database, MySQLDatabase):
            return MySQLMigrator(database)

        raise ValueError("Unsupported database: %s" % database)

    def drop_table(self, model: TModelType, *, cascade: bool = True) -> Callable[[], Any]:
        """Drop table.

        Returns a zero-argument callable; the table is dropped only when that
        callable is invoked.
        """
        return lambda: model.drop_table(cascade=cascade)

    @operation
    def change_column(
        self, table: str, column_name: str, field: pw.Field
    ) -> List[Union[Context, Operation]]:
        """Change column.

        Emits the column alteration first and, when the field is not
        nullable, a separate ``add_not_null`` step afterwards.
        """
        operations: List[Union[Context, Operation]] = [
            self.alter_change_column(table, column_name, field)
        ]
        if not field.null:
            operations.extend([self.add_not_null(table, column_name)])
        return operations

    def alter_change_column(self, table: str, column: str, field: pw.Field) -> Context:
        """Support change columns.

        Builds ``ALTER TABLE <table> ALTER COLUMN <field ddl>``.
        """
        ctx = self.make_context()
        # The DDL is generated with ``null`` forced to True (nullability is
        # handled separately by ``change_column``); presumably this keeps a
        # NOT NULL clause out of the generated DDL. The original value is
        # restored even if DDL generation raises, so the caller's field
        # object is never left mutated.
        field_null, field.null = field.null, True
        try:
            ctx = self._alter_table(ctx, table).literal(" ALTER COLUMN ").sql(field.ddl(ctx))
        finally:
            field.null = field_null
        return ctx

    @operation
    def sql(self, sql: str, *params) -> SQL:
        """Execute raw SQL."""
        return SQL(sql, *params)

    def alter_add_column(
        self, table: str, column_name: str, field: pw.Field, **kwargs
    ) -> Operation:
        """Fix fieldname for ForeignKeys.

        Remembers ``field.name`` before delegating to the parent
        implementation and writes it back afterwards for foreign keys.
        """
        name = field.name
        op = super(SchemaMigrator, self).alter_add_column(table, column_name, field, **kwargs)
        if isinstance(field, pw.ForeignKeyField):
            field.name = name
        return op


class MySQLMigrator(SchemaMigrator, MqM):
    """Support MySQL."""

    def alter_change_column(self, table: str, column: str, field: pw.Field) -> Context:
        """Support change columns.

        Same as the base implementation but uses ``MODIFY COLUMN``.
        """
        ctx = self.make_context()
        # Restore ``field.null`` even when DDL generation fails.
        field_null, field.null = field.null, True
        try:
            ctx = self._alter_table(ctx, table).literal(" MODIFY COLUMN ").sql(field.ddl(ctx))
        finally:
            field.null = field_null
        return ctx


class PostgresqlMigrator(SchemaMigrator, PgM):
    """Support the migrations in postgresql."""

    def alter_change_column(self, table: str, column_name: str, field: pw.Field) -> Context:
        """Support change columns.

        Takes the base ``ALTER COLUMN`` context and inserts ``TYPE`` followed
        by a space just before its last SQL fragment.
        """
        context = super(PostgresqlMigrator, self).alter_change_column(table, column_name, field)
        context._sql.insert(-1, "TYPE")  # type: ignore[]
        context._sql.insert(-1, " ")  # type: ignore[]
        return context


class SqliteMigrator(SchemaMigrator, SqM):
    """Support the migrations in sqlite."""

    def drop_table(self, model: pw.Model, *, cascade: bool = True) -> Callable:
        """Sqlite doesnt support cascade syntax by default.

        The ``cascade`` argument is accepted for interface compatibility and
        ignored; the table is always dropped with ``cascade=False``.
        """
        return lambda: model.drop_table(cascade=False)

    def alter_change_column(self, table: str, column: str, field: pw.Field) -> Operation:
        """Support change columns.

        Delegates to ``_update_column`` with a callback that returns the
        field's freshly generated DDL as the new column definition.
        """

        def fn(c_name, c_def):
            # The existing name/definition are ignored; only the new DDL is used.
            ctx = self.make_context()
            ctx.sql(field.ddl(ctx))
            return ctx.query()[0]

        return self._update_column(table, column, fn)  # type: ignore[]


class ORM:
    """Registry of models, addressable by class name and by table name."""

    __slots__ = ("__tables__", "__models__")

    def __init__(self):
        """Create empty name -> model and table -> model mappings."""
        self.__tables__ = {}
        self.__models__ = {}

    def add(self, model: TModelType):
        """Register a model under its class name and its table name."""
        self.__models__[model.__name__] = model
        self.__tables__[model._meta.table_name] = model  # type: ignore[]

    def remove(self, model: TModelType):
        """Unregister a model; raises ``KeyError`` if it was not registered."""
        del self.__models__[model.__name__]
        del self.__tables__[model._meta.table_name]  # type: ignore[]

    def __getattr__(self, name: str) -> TModelType:
        """Look a model up by class name; raises ``KeyError`` when unknown."""
        return self.__models__[name]

    def __getitem__(self, name: str) -> TModelType:
        """Look a model up by table name; raises ``KeyError`` when unknown."""
        return self.__tables__[name]

    def __iter__(self):
        """Iterate over the registered models."""
        return iter(self.__models__.values())


class Migrator:
    """Provide migrations."""

    def __init__(self, database: Union[pw.Database, pw.Proxy]):
        """Initialize the migrator."""
        self.orm: ORM = ORM()

        if isinstance(database, pw.Proxy):
            database = database.obj

        self.__database__ = database
        # Pending operations; they are executed (and cleared) by ``__call__``.
        self.__ops__: List[Union[Operation, Callable]] = []
        self.__migrator__ = SchemaMigrator.from_database(database)

    def __call__(self):
        """Run operations."""
        for op in self.__ops__:
            if isinstance(op, Operation):
                logger.info("%s %s", op.method, op.args)
                op.run()
            else:
                logger.info("Run %s", op.__name__)
                op()
        self.__ops__ = []

    def __iter__(self):
        """Iterate over models."""
        return iter(self.orm)

    @overload
    def __get_model__(self, model: TVModelType) -> TVModelType:
        ...

    @overload
    def __get_model__(self, model: str) -> TModelType:
        ...

    def __get_model__(self, model: Union[TVModelType, str]) -> Union[TVModelType, TModelType]:
        """Get model by name.

        A string is resolved first as a model class name, then as a table
        name; anything that is not a string is returned unchanged.

        Raises:
            ValueError: if a string matches neither a model nor a table.
        """
        if isinstance(model, str):
            if model in self.orm.__models__:
                return self.orm.__models__[model]
            if model in self.orm.__tables__:
                return self.orm[model]

            raise ValueError("Model %s not found" % model)

        return model

    def sql(self, sql: str, *params):
        """Execute raw SQL."""
        op = cast(Operation, self.__migrator__.sql(sql, *params))
        self.__ops__.append(op)

    def python(self, func: Callable, *args, **kwargs):
        """Run a python function."""
        self.__ops__.append(lambda: func(*args, **kwargs))

    def create_table(self, model: TVModelType) -> TVModelType:
        """Create model and table in database.

        >> migrator.create_table(model)
        """
        meta = model._meta  # type: ignore[]
        self.orm.add(model)

        meta.database = self.__database__
        self.__ops__.append(model.create_table)
        return model

    create_model = create_table

    def drop_table(self, model: Union[str, TModelType], *, cascade: bool = True):
        """Drop model and table from database.

        >> migrator.drop_table(model, cascade=True)
        """
        model = self.__get_model__(model)
        self.orm.remove(model)
        self.__ops__.append(self.__migrator__.drop_table(model, cascade=cascade))

    remove_model = drop_table

    def add_columns(self, model: Union[str, TModelType], **fields: pw.Field) -> TModelType:
        """Create new fields.

        Each field is attached to the model and an ``add_column`` operation
        is queued; unique fields additionally get a unique index.
        """
        model = self.__get_model__(model)
        meta = model._meta  # type: ignore[]
        for name, field in fields.items():
            meta.add_field(name, field)
            self.__ops__.append(
                self.__migrator__.add_column(  # type: ignore[]
                    meta.table_name, field.column_name, field
                )
            )
            if field.unique:
                self.__ops__.append(
                    self.__migrator__.add_index(meta.table_name, (field.column_name,), unique=True)
                )
        return model

    add_fields = add_columns

    def change_columns(self, model: Union[str, TModelType], **fields: pw.Field) -> TModelType:
        """Change fields.

        For every ``name=field`` pair the new field replaces the one
        registered under ``name`` and the matching schema operations are
        queued: dropping an old foreign-key constraint, renaming the column
        when its name changed, then either re-adding a foreign-key constraint
        or altering the column and syncing its unique index.
        """
        model = self.__get_model__(model)
        meta = model._meta  # type: ignore[]
        for name, field in fields.items():
            # Falls back to the new field itself when ``name`` is unknown.
            old_field = meta.fields.get(name, field)
            old_column_name = old_field and old_field.column_name

            meta.add_field(name, field)

            if isinstance(old_field, pw.ForeignKeyField):
                self.__ops__.append(
                    self.__migrator__.drop_foreign_key_constraint(meta.table_name, old_column_name)
                )

            if old_column_name != field.column_name:
                self.__ops__.append(
                    self.__migrator__.rename_column(
                        meta.table_name, old_column_name, field.column_name
                    )
                )

            if isinstance(field, pw.ForeignKeyField):
                on_delete = field.on_delete if field.on_delete else "RESTRICT"
                on_update = field.on_update if field.on_update else "RESTRICT"
                self.__ops__.append(
                    self.__migrator__.add_foreign_key_constraint(
                        meta.table_name,
                        field.column_name,
                        field.rel_model._meta.table_name,
                        field.rel_field.name,
                        on_delete,
                        on_update,
                    )
                )
                continue

            self.__ops__.append(
                self.__migrator__.change_column(  # type: ignore[]
                    meta.table_name, field.column_name, field
                )
            )

            if field.unique == old_field.unique:
                continue

            if field.unique:
                index = (field.column_name,), field.unique
                self.__ops__.append(self.__migrator__.add_index(meta.table_name, *index))
                meta.indexes.append(index)
            else:
                # ``drop_index`` expects an index name, not a column name.
                # The name is derived from the column the index was built on
                # (the pre-rename column), since renaming a column is not
                # expected to rename its index.
                index_name = make_index_name(meta.table_name, [old_column_name])
                self.__ops__.append(self.__migrator__.drop_index(meta.table_name, index_name))
                # The unique flag may have been declared on the field only and
                # never listed in ``meta.indexes``; only remove what is there.
                stale_index = ((field.column_name,), old_field.unique)
                if stale_index in meta.indexes:
                    meta.indexes.remove(stale_index)

        return model

    change_fields = change_columns

    def drop_columns(
        self, model: Union[str, TModelType], *names: str, cascade: bool = True
    ) -> TModelType:
        """Remove fields from model.

        Names that do not match any field of the model are ignored.
        """
        model = self.__get_model__(model)
        meta = model._meta  # type: ignore[]
        fields = [field for field in meta.fields.values() if field.name in names]
        for field in fields:
            self.__del_field__(model, field)
            if field.unique:
                index_name = make_index_name(meta.table_name, [field.column_name])
                self.__ops__.append(self.__migrator__.drop_index(meta.table_name, index_name))
            self.__ops__.append(
                self.__migrator__.drop_column(  # type: ignore[]
                    meta.table_name, field.column_name, cascade=cascade
                )
            )
        return model

    remove_fields = drop_columns

    def __del_field__(self, model: TModelType, field: pw.Field):
        """Delete field from model.

        For foreign keys the ``<name>_id`` accessor on the model (when
        present) and the backref attribute on the related model are removed
        as well.
        """
        meta = model._meta  # type: ignore[]
        meta.remove_field(field.name)
        delattr(model, field.name)
        if isinstance(field, pw.ForeignKeyField):
            obj_id_name = field.column_name
            if field.column_name == field.name:
                obj_id_name += "_id"
            if hasattr(model, obj_id_name):
                delattr(model, obj_id_name)
            delattr(field.rel_model, field.backref)

    def rename_column(
        self, model: Union[str, TModelType], old_name: str, new_name: str
    ) -> TModelType:
        """Rename field in model.

        For foreign keys the database column being renamed is the field's
        current ``column_name`` and the new column gets an ``_id`` suffix.

        Raises:
            KeyError: if the model has no field called ``old_name``.
        """
        model = self.__get_model__(model)
        meta = model._meta  # type: ignore[]
        field = meta.fields[old_name]
        if isinstance(field, pw.ForeignKeyField):
            old_name = field.column_name
        self.__del_field__(model, field)
        field.name = field.column_name = new_name
        if isinstance(field, pw.ForeignKeyField):
            field.column_name = field.column_name + "_id"
        meta.add_field(new_name, field)
        self.__ops__.append(
            self.__migrator__.rename_column(meta.table_name, old_name, field.column_name)
        )
        return model

    rename_field = rename_column

    def rename_table(self, model: Union[str, TModelType], new_name: str) -> TModelType:
        """Rename table in database."""
        model = self.__get_model__(model)
        meta = model._meta  # type: ignore[]
        old_name = meta.table_name
        # Re-register so the table-name lookup follows the new name.
        self.orm.remove(model)
        meta.table_name = new_name
        self.orm.add(model)
        self.__ops__.append(self.__migrator__.rename_table(old_name, new_name))
        return model

    def add_index(self, model: Union[str, TModelType], *columns: str, unique=False) -> TModelType:
        """Create indexes.

        ``columns`` are field names; they are translated to column names for
        the queued operation. For a single-column index the field's
        ``unique``/``index`` flags are updated to match.
        """
        model = self.__get_model__(model)
        meta = model._meta  # type: ignore[]
        meta.indexes.append((columns, unique))
        columns_ = []
        for col in columns:
            field = meta.fields.get(col)

            if len(columns) == 1:
                field.unique = unique
                field.index = not unique

            columns_.append(field.column_name)

        self.__ops__.append(self.__migrator__.add_index(meta.table_name, columns_, unique=unique))
        return model

    def drop_index(self, model: Union[str, TModelType], *columns: str) -> TModelType:
        """Drop indexes.

        ``columns`` are field names; unknown names are skipped. The index
        name is derived from the table name and the resolved column names.
        """
        model = self.__get_model__(model)
        meta = model._meta  # type: ignore[]
        columns_ = []
        for col in columns:
            field = meta.fields.get(col)
            if not field:
                continue

            if len(columns) == 1:
                field.unique = field.index = False

            columns_.append(field.column_name)

        index_name = make_index_name(meta.table_name, columns_)
        meta.indexes = [(cols, _) for (cols, _) in meta.indexes if columns != cols]
        self.__ops__.append(self.__migrator__.drop_index(meta.table_name, index_name))
        return model

    def add_not_null(self, model: Union[str, TModelType], *names: str) -> TModelType:
        """Add not null."""
        model = self.__get_model__(model)
        meta = model._meta  # type: ignore[]
        for name in names:
            field = meta.fields[name]
            field.null = False
            self.__ops__.append(self.__migrator__.add_not_null(meta.table_name, field.column_name))
        return model

    def drop_not_null(self, model: Union[str, TModelType], *names: str) -> TModelType:
        """Drop not null."""
        model = self.__get_model__(model)
        meta = model._meta  # type: ignore[]
        for name in names:
            field = meta.fields[name]
            field.null = True
            self.__ops__.append(self.__migrator__.drop_not_null(meta.table_name, field.column_name))
        return model

    def add_default(self, model: Union[str, TModelType], name: str, default: Any) -> TModelType:
        """Add default.

        ``name`` is the field name on the model. The queued operation targets
        the field's database column, which can differ from the field name
        (foreign keys, explicit ``column_name``).

        Raises:
            KeyError: if the model has no field called ``name``.
        """
        model = self.__get_model__(model)
        meta = model._meta  # type: ignore[]
        field = meta.fields[name]
        meta.defaults[field] = field.default = default
        self.__ops__.append(
            self.__migrator__.apply_default(meta.table_name, field.column_name, field)
        )
        return model