You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

306 lines
8.5 KiB
Python

import re
from sqlalchemy import __version__
from sqlalchemy import inspect
from sqlalchemy import schema
from sqlalchemy import sql
from sqlalchemy import types as sqltypes
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import CheckConstraint
from sqlalchemy.schema import Column
from sqlalchemy.schema import ForeignKeyConstraint
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.expression import _BindParamClause
from sqlalchemy.sql.expression import _TextClause as TextClause
from sqlalchemy.sql.visitors import traverse
from . import compat
def _safe_int(value):
    """Convert ``value`` to ``int`` when possible, otherwise return it unchanged.

    Used to turn the numeric tokens of a version string into integers while
    leaving non-numeric tokens (e.g. a pre-release marker like ``"b1"``) as
    strings.

    Only the errors ``int()`` raises for unconvertible input (``ValueError``,
    ``TypeError``) are treated as "not an int". Anything else, including
    ``KeyboardInterrupt`` and ``SystemExit``, propagates. The previous bare
    ``except:`` silently swallowed those.
    """
    try:
        return int(value)
    except (TypeError, ValueError):
        return value
# Installed SQLAlchemy version as a tuple. Numeric tokens become ints; a
# pre-release marker matched by ``[abc]\d`` (e.g. "b1") stays a string, because
# _safe_int returns anything int() cannot convert unchanged.
_vers = tuple(
    _safe_int(token)
    for token in re.findall(r"(\d+|[abc]\d)", __version__)
)

# Feature flags keyed to the SQLAlchemy version. Tuples compare element-wise,
# and a longer tuple with an equal prefix compares greater, so e.g.
# (1, 4, 0) >= (1, 4) is True.
sqla_110 = _vers >= (1, 1, 0)
sqla_1115 = _vers >= (1, 1, 15)
sqla_120 = _vers >= (1, 2, 0)
sqla_1216 = _vers >= (1, 2, 16)
sqla_13 = _vers >= (1, 3)
sqla_14 = _vers >= (1, 4)
# Probe for the ``Computed`` construct. When it is not importable, computed
# columns are unsupported altogether. Reflection of computed columns is
# additionally gated on SQLAlchemy 1.3.16+.
try:
    from sqlalchemy import Computed  # noqa
except ImportError:
    has_computed = has_computed_reflection = False
else:
    has_computed = True
    has_computed_reflection = _vers >= (1, 3, 16)

AUTOINCREMENT_DEFAULT = "auto"
def _connectable_has_table(connectable, tablename, schemaname):
    """Return the ``has_table`` result for ``tablename`` in ``schemaname``.

    Before SQLAlchemy 1.4 this goes through ``connectable.dialect.has_table``.
    From 1.4 on it goes through the inspector returned by ``inspect()``.
    """
    if not sqla_14:
        dialect = connectable.dialect
        return dialect.has_table(connectable, tablename, schemaname)
    return inspect(connectable).has_table(tablename, schemaname)
def _exec_on_inspector(inspector, statement, **params):
    """Execute ``statement`` using the inspector's connection.

    The keyword arguments are gathered into a dict and passed as the
    statement's parameters. Before SQLAlchemy 1.4 the inspector's ``bind`` is
    used. From 1.4 on, a connection is obtained from the inspector's
    ``_operation_context()``.
    """
    if not sqla_14:
        return inspector.bind.execute(statement, params)
    with inspector._operation_context() as conn:
        return conn.execute(statement, params)
def _server_default_is_computed(column):
    """Return True if ``column.computed`` is a ``Computed`` construct.

    Always False when ``Computed`` could not be imported.
    """
    # The short-circuit matters: ``Computed`` is only bound when has_computed
    # is true, so it must not be evaluated otherwise.
    return has_computed and isinstance(column.computed, Computed)
def _table_for_constraint(constraint):
    """Return the table a constraint is attached to.

    For a ForeignKeyConstraint this is ``constraint.parent``. For every other
    constraint it is ``constraint.table``.
    """
    is_foreign_key = isinstance(constraint, ForeignKeyConstraint)
    return constraint.parent if is_foreign_key else constraint.table
def _columns_for_constraint(constraint):
    """Return the columns participating in ``constraint``.

    - ForeignKeyConstraint: a list of each element's ``parent`` column.
    - CheckConstraint: the columns found in its ``sqltext`` expression,
      returned as a *set* (from ``_find_columns``).
    - Anything else: ``constraint.columns`` copied into a list.
    """
    if isinstance(constraint, ForeignKeyConstraint):
        return [element.parent for element in constraint.elements]
    if isinstance(constraint, CheckConstraint):
        # CHECK constraints carry an expression rather than a column list,
        # so the expression is scanned for Column objects.
        return _find_columns(constraint.sqltext)
    return list(constraint.columns)
def _fk_spec(constraint):
    """Describe a foreign key constraint as a flat 10-tuple.

    Returns ``(source_schema, source_table, source_columns, target_schema,
    target_table, target_columns, onupdate, ondelete, deferrable,
    initially)``. Note that ``onupdate`` precedes ``ondelete``.

    The source columns are looked up through ``constraint.column_keys``. The
    target schema and table are taken from the table of the first element's
    referenced column.
    """
    source_columns = [
        constraint.columns[col_key].name
        for col_key in constraint.column_keys
    ]
    owner = constraint.parent
    elements = constraint.elements
    referred = elements[0].column.table
    target_columns = [elem.column.name for elem in elements]
    return (
        owner.schema,
        owner.name,
        source_columns,
        referred.schema,
        referred.name,
        target_columns,
        constraint.onupdate,
        constraint.ondelete,
        constraint.deferrable,
        constraint.initially,
    )
def _fk_is_self_referential(constraint):
    """Return True if the constraint's first element targets its own table.

    The first element's column spec has its final dotted token (the column
    name) removed, and the remainder is compared against
    ``constraint.parent.key``.
    """
    colspec = constraint.elements[0]._get_colspec()
    # Everything before the last "." ("" when there is no dot), which is the
    # same as splitting on "." and dropping the last token.
    table_key = colspec.rpartition(".")[0]
    return table_key == constraint.parent.key
def _is_type_bound(constraint):
    """Return the constraint's ``_type_bound`` flag.

    This handles SQLAlchemy #3260. CHECK constraints that will be generated
    by the type itself should not be copied, and this flag (added for #3260)
    identifies them.
    """
    type_bound = constraint._type_bound
    return type_bound
def _find_columns(clause):
    """Return the set of Column objects located within ``clause``."""
    found = set()
    # The "column" visitor adds each column element the traversal visits.
    traverse(clause, {}, {"column": found.add})
    return found
def _remove_column_from_collection(collection, column):
    """Remove the column stored under ``column.key`` from a ColumnCollection.

    This is a workaround for older SQLAlchemy. The object actually present in
    the collection is looked up by key and removed, rather than passing
    ``column`` itself to ``remove()``.
    """
    collection.remove(collection[column.key])
def _textual_index_column(table, text_):
    """Produce an index element from a string or ``text()`` construct.

    This works around the Index construct's lack of flexibility. A string
    becomes a NULLTYPE Column that is appended to ``table`` and returned. A
    text() construct is wrapped in ``_textual_index_element``. Any other input
    raises ValueError.
    """
    if isinstance(text_, compat.string_types):
        column = Column(text_, sqltypes.NULLTYPE)
        table.append_column(column)
        return column
    if isinstance(text_, TextClause):
        return _textual_index_element(table, text_)
    raise ValueError("String or text() construct expected")
class _textual_index_element(sql.ColumnElement):
    """Wrap a SQLAlchemy text() construct so it looks like a column-oriented
    SQL expression to an Index construct.

    Currently the PostgreSQL dialect, the main user of functional indexes,
    matches every index expression to a corresponding column expression when
    rendering CREATE INDEX. The Index built here therefore needs a ``.columns``
    collection of the same length as its ``.expressions`` collection.
    Ultimately SQLAlchemy should support text() expressions in indexes
    directly.

    See SQLAlchemy issue 3174.
    """

    __visit_name__ = "_textual_idx_element"

    def __init__(self, table, text):
        expression_text = text.text
        self.table = table
        self.text = text
        self.key = expression_text
        # Placeholder column named after the expression text. It is appended
        # to the table so the Index's .columns stays aligned with .expressions.
        self.fake_column = schema.Column(expression_text, sqltypes.NULLTYPE)
        table.append_column(self.fake_column)

    def get_children(self):
        # The placeholder column is the only child exposed to traversal.
        return [self.fake_column]
@compiles(_textual_index_element)
def _render_textual_index_column(element, compiler, **kw):
    """Compile a ``_textual_index_element`` by compiling its wrapped text()."""
    wrapped_text = element.text
    return compiler.process(wrapped_text, **kw)
class _literal_bindparam(_BindParamClause):
    """Marker subclass of ``_BindParamClause`` with no behavior of its own.

    It exists as a distinct type so a ``@compiles`` rule can be registered
    for it.
    """
@compiles(_literal_bindparam)
def _render_literal_bindparam(element, compiler, **kw):
    """Compile a ``_literal_bindparam`` via the compiler's
    ``render_literal_bindparam()``."""
    render = compiler.render_literal_bindparam
    return render(element, **kw)
def _get_index_expressions(idx):
    """Return ``idx.expressions`` copied into a new list."""
    expressions = idx.expressions
    return list(expressions)
def _get_index_column_names(idx):
    """Return the ``name`` of each index expression, in order.

    Expressions without a ``name`` attribute contribute None.
    """
    names = []
    for expression in _get_index_expressions(idx):
        names.append(getattr(expression, "name", None))
    return names
def _column_kwargs(col):
    """Return ``col.kwargs`` on SQLAlchemy 1.3+, otherwise an empty dict."""
    return col.kwargs if sqla_13 else {}
def _get_constraint_final_name(constraint, dialect):
    """Return the compiled name of ``constraint`` for ``dialect``, unquoted.

    Returns None when ``constraint.name`` is None.

    On SQLAlchemy 1.4+ the dialect's identifier preparer produces the name
    directly, with ``_alembic_quote=False``. On older versions the name is
    re-wrapped with ``quote=False`` on a fresh constraint of the same class.
    That fresh constraint is then run through the DDL compiler's
    ``_prepared_index_name()`` (for an Index) or the preparer's
    ``format_constraint()`` (for anything else).
    """
    if constraint.name is None:
        return None
    elif sqla_14:
        # for SQLAlchemy 1.4 we would like to have the option to expand
        # the use of "deferred" names for constraints as well as to have
        # some flexibility with "None" name and similar; make use of new
        # SQLAlchemy API to return what would be the final compiled form of
        # the name for this dialect.
        return dialect.identifier_preparer.format_constraint(
            constraint, _alembic_quote=False
        )
    else:
        # prior to SQLAlchemy 1.4, work around quoting logic to get at the
        # final compiled name without quotes.
        if hasattr(constraint.name, "quote"):
            # might be quoted_name, might be truncated_name, keep it the
            # same
            quoted_name_cls = type(constraint.name)
        else:
            quoted_name_cls = quoted_name
        new_name = quoted_name_cls(str(constraint.name), quote=False)
        # Build a throwaway constraint of the same class that carries only
        # the unquoted name. It is what gets passed to the preparer/compiler
        # below. NOTE(review): this assumes every constraint class reaching
        # this branch can be constructed with only ``name=``. Confirm for
        # ForeignKeyConstraint, whose constructor may require column
        # arguments.
        constraint = constraint.__class__(name=new_name)
        if isinstance(constraint, schema.Index):
            # name should not be quoted.
            # Indexes go through the DDL compiler's private
            # _prepared_index_name() helper rather than format_constraint().
            return dialect.ddl_compiler(dialect, None)._prepared_index_name(
                constraint
            )
        else:
            # name should not be quoted.
            return dialect.identifier_preparer.format_constraint(constraint)
def _constraint_is_named(constraint, dialect):
    """Return True if ``constraint`` ends up with a name for ``dialect``.

    A constraint whose ``name`` is None is never named. Before SQLAlchemy 1.4,
    any non-None name counts. From 1.4 on, the name must also survive the
    dialect's ``format_constraint()`` as a non-None value.
    """
    if constraint.name is None:
        return False
    if not sqla_14:
        return True
    formatted = dialect.identifier_preparer.format_constraint(
        constraint, _alembic_quote=False
    )
    return formatted is not None
def _dialect_supports_comments(dialect):
    """Return ``dialect.supports_comments`` on SQLAlchemy 1.2+, else False."""
    return dialect.supports_comments if sqla_120 else False
def _comment_attribute(obj):
    """Return the ``.comment`` attribute of a Table or Column.

    Returns None on SQLAlchemy versions before 1.2.
    """
    return obj.comment if sqla_120 else None
def _is_mariadb(mysql_dialect):
    """Report whether ``"MariaDB"`` appears in the dialect's
    ``server_version_info``.

    A falsy ``server_version_info`` (e.g. None or an empty tuple) is returned
    as-is. Otherwise the result of the membership test is returned.
    """
    version_info = mysql_dialect.server_version_info
    if not version_info:
        return version_info
    return "MariaDB" in version_info
def _mariadb_normalized_version_info(mysql_dialect):
    """Return ``server_version_info``, with its first three entries dropped
    when it has more than five entries.

    NOTE(review): the dropped entries are presumably a MySQL-compatibility
    version prefix reported by MariaDB servers. Confirm against the dialect.
    """
    version_info = mysql_dialect.server_version_info
    return version_info[3:] if len(version_info) > 5 else version_info
if sqla_14:
    from sqlalchemy import create_mock_engine
else:
    from sqlalchemy import create_engine

    def create_mock_engine(url, executor):
        """Fallback for SQLAlchemy < 1.4, which has no ``create_mock_engine``.

        Builds a "mock" strategy engine whose statements are routed to
        ``executor``. ``url`` is now forwarded to ``create_engine``, so the
        dialect follows the caller's URL as it does with the 1.4
        ``create_mock_engine(url, executor)``. Previously the URL was ignored
        and "postgresql://" was always used, which silently gave callers a
        PostgreSQL dialect regardless of what they asked for.
        """
        return create_engine(url, strategy="mock", executor=executor)