# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls # mypy: no-warn-return-any, allow-any-generics from __future__ import annotations from typing import Any from typing import ClassVar from typing import Dict from typing import Generic from typing import NamedTuple from typing import Optional from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING from typing import TypeVar from typing import Union from sqlalchemy.sql.schema import Constraint from sqlalchemy.sql.schema import ForeignKeyConstraint from sqlalchemy.sql.schema import Index from sqlalchemy.sql.schema import UniqueConstraint from typing_extensions import TypeGuard from .. import util from ..util import sqla_compat if TYPE_CHECKING: from typing import Literal from alembic.autogenerate.api import AutogenContext from alembic.ddl.impl import DefaultImpl CompareConstraintType = Union[Constraint, Index] _C = TypeVar("_C", bound=CompareConstraintType) _clsreg: Dict[str, Type[_constraint_sig]] = {} class ComparisonResult(NamedTuple): status: Literal["equal", "different", "skip"] message: str @property def is_equal(self) -> bool: return self.status == "equal" @property def is_different(self) -> bool: return self.status == "different" @property def is_skip(self) -> bool: return self.status == "skip" @classmethod def Equal(cls) -> ComparisonResult: """the constraints are equal.""" return cls("equal", "The two constraints are equal") @classmethod def Different(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult: """the constraints are different for the provided reason(s).""" return cls("different", ", ".join(util.to_list(reason))) @classmethod def Skip(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult: """the constraint cannot be compared for the provided reason(s). The message is logged, but the constraints will be otherwise considered equal, meaning that no migration command will be generated. """ return cls("skip", ", ".join(util.to_list(reason))) class _constraint_sig(Generic[_C]): const: _C _sig: Tuple[Any, ...] name: Optional[sqla_compat._ConstraintNameDefined] impl: DefaultImpl _is_index: ClassVar[bool] = False _is_fk: ClassVar[bool] = False _is_uq: ClassVar[bool] = False _is_metadata: bool def __init_subclass__(cls) -> None: cls._register() @classmethod def _register(cls): raise NotImplementedError() def __init__( self, is_metadata: bool, impl: DefaultImpl, const: _C ) -> None: raise NotImplementedError() def compare_to_reflected( self, other: _constraint_sig[Any] ) -> ComparisonResult: assert self.impl is other.impl assert self._is_metadata assert not other._is_metadata return self._compare_to_reflected(other) def _compare_to_reflected( self, other: _constraint_sig[_C] ) -> ComparisonResult: raise NotImplementedError() @classmethod def from_constraint( cls, is_metadata: bool, impl: DefaultImpl, constraint: _C ) -> _constraint_sig[_C]: # these could be cached by constraint/impl, however, if the # constraint is modified in place, then the sig is wrong. the mysql # impl currently does this, and if we fixed that we can't be sure # someone else might do it too, so play it safe. sig = _clsreg[constraint.__visit_name__](is_metadata, impl, constraint) return sig def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]: return sqla_compat._get_constraint_final_name( self.const, context.dialect ) @util.memoized_property def is_named(self): return sqla_compat._constraint_is_named(self.const, self.impl.dialect) @util.memoized_property def unnamed(self) -> Tuple[Any, ...]: return self._sig @util.memoized_property def unnamed_no_options(self) -> Tuple[Any, ...]: raise NotImplementedError() @util.memoized_property def _full_sig(self) -> Tuple[Any, ...]: return (self.name,) + self.unnamed def __eq__(self, other) -> bool: return self._full_sig == other._full_sig def __ne__(self, other) -> bool: return self._full_sig != other._full_sig def __hash__(self) -> int: return hash(self._full_sig) class _uq_constraint_sig(_constraint_sig[UniqueConstraint]): _is_uq = True @classmethod def _register(cls) -> None: _clsreg["unique_constraint"] = cls is_unique = True def __init__( self, is_metadata: bool, impl: DefaultImpl, const: UniqueConstraint, ) -> None: self.impl = impl self.const = const self.name = sqla_compat.constraint_name_or_none(const.name) self._sig = tuple(sorted([col.name for col in const.columns])) self._is_metadata = is_metadata @property def column_names(self) -> Tuple[str, ...]: return tuple([col.name for col in self.const.columns]) def _compare_to_reflected( self, other: _constraint_sig[_C] ) -> ComparisonResult: assert self._is_metadata metadata_obj = self conn_obj = other assert is_uq_sig(conn_obj) return self.impl.compare_unique_constraint( metadata_obj.const, conn_obj.const ) class _ix_constraint_sig(_constraint_sig[Index]): _is_index = True name: sqla_compat._ConstraintName @classmethod def _register(cls) -> None: _clsreg["index"] = cls def __init__( self, is_metadata: bool, impl: DefaultImpl, const: Index ) -> None: self.impl = impl self.const = const self.name = const.name self.is_unique = bool(const.unique) self._is_metadata = is_metadata def _compare_to_reflected( self, other: _constraint_sig[_C] ) -> ComparisonResult: assert self._is_metadata metadata_obj = self conn_obj = other assert is_index_sig(conn_obj) return self.impl.compare_indexes(metadata_obj.const, conn_obj.const) @util.memoized_property def has_expressions(self): return sqla_compat.is_expression_index(self.const) @util.memoized_property def column_names(self) -> Tuple[str, ...]: return tuple([col.name for col in self.const.columns]) @util.memoized_property def column_names_optional(self) -> Tuple[Optional[str], ...]: return tuple( [getattr(col, "name", None) for col in self.const.expressions] ) @util.memoized_property def is_named(self): return True @util.memoized_property def unnamed(self): return (self.is_unique,) + self.column_names_optional class _fk_constraint_sig(_constraint_sig[ForeignKeyConstraint]): _is_fk = True @classmethod def _register(cls) -> None: _clsreg["foreign_key_constraint"] = cls def __init__( self, is_metadata: bool, impl: DefaultImpl, const: ForeignKeyConstraint, ) -> None: self._is_metadata = is_metadata self.impl = impl self.const = const self.name = sqla_compat.constraint_name_or_none(const.name) ( self.source_schema, self.source_table, self.source_columns, self.target_schema, self.target_table, self.target_columns, onupdate, ondelete, deferrable, initially, ) = sqla_compat._fk_spec(const) self._sig: Tuple[Any, ...] = ( self.source_schema, self.source_table, tuple(self.source_columns), self.target_schema, self.target_table, tuple(self.target_columns), ) + ( ( (None if onupdate.lower() == "no action" else onupdate.lower()) if onupdate else None ), ( (None if ondelete.lower() == "no action" else ondelete.lower()) if ondelete else None ), # convert initially + deferrable into one three-state value ( "initially_deferrable" if initially and initially.lower() == "deferred" else "deferrable" if deferrable else "not deferrable" ), ) @util.memoized_property def unnamed_no_options(self): return ( self.source_schema, self.source_table, tuple(self.source_columns), self.target_schema, self.target_table, tuple(self.target_columns), ) def is_index_sig(sig: _constraint_sig) -> TypeGuard[_ix_constraint_sig]: return sig._is_index def is_uq_sig(sig: _constraint_sig) -> TypeGuard[_uq_constraint_sig]: return sig._is_uq def is_fk_sig(sig: _constraint_sig) -> TypeGuard[_fk_constraint_sig]: return sig._is_fk