You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

485 lines
14 KiB
Python

# testing/exclusions.py
# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import contextlib
import operator
import re
from sqlalchemy import util as sqla_util
from sqlalchemy.util import decorator
from . import config
from . import fixture_functions
from .. import util
from ..util.compat import inspect_getargspec
def skip_if(predicate, reason=None):
rule = compound()
pred = _as_predicate(predicate, reason)
rule.skips.add(pred)
return rule
def fails_if(predicate, reason=None):
rule = compound()
pred = _as_predicate(predicate, reason)
rule.fails.add(pred)
return rule
class compound(object):
def __init__(self):
self.fails = set()
self.skips = set()
self.tags = set()
self.combinations = {}
def __add__(self, other):
return self.add(other)
def with_combination(self, **kw):
copy = compound()
copy.fails.update(self.fails)
copy.skips.update(self.skips)
copy.tags.update(self.tags)
copy.combinations.update((f, kw) for f in copy.fails)
copy.combinations.update((s, kw) for s in copy.skips)
return copy
def add(self, *others):
copy = compound()
copy.fails.update(self.fails)
copy.skips.update(self.skips)
copy.tags.update(self.tags)
for other in others:
copy.fails.update(other.fails)
copy.skips.update(other.skips)
copy.tags.update(other.tags)
return copy
def not_(self):
copy = compound()
copy.fails.update(NotPredicate(fail) for fail in self.fails)
copy.skips.update(NotPredicate(skip) for skip in self.skips)
copy.tags.update(self.tags)
return copy
@property
def enabled(self):
return self.enabled_for_config(config._current)
def enabled_for_config(self, config):
for predicate in self.skips.union(self.fails):
if predicate(config):
return False
else:
return True
def matching_config_reasons(self, config):
return [
predicate._as_string(config)
for predicate in self.skips.union(self.fails)
if predicate(config)
]
def include_test(self, include_tags, exclude_tags):
return bool(
not self.tags.intersection(exclude_tags)
and (not include_tags or self.tags.intersection(include_tags))
)
def _extend(self, other):
self.skips.update(other.skips)
self.fails.update(other.fails)
self.tags.update(other.tags)
self.combinations.update(other.combinations)
def __call__(self, fn):
if hasattr(fn, "_sa_exclusion_extend"):
fn._sa_exclusion_extend._extend(self)
return fn
@decorator
def decorate(fn, *args, **kw):
return self._do(config._current, fn, *args, **kw)
decorated = decorate(fn)
decorated._sa_exclusion_extend = self
return decorated
@contextlib.contextmanager
def fail_if(self):
all_fails = compound()
all_fails.fails.update(self.skips.union(self.fails))
try:
yield
except Exception as ex:
all_fails._expect_failure(config._current, ex, None)
else:
all_fails._expect_success(config._current, None)
def _check_combinations(self, combination, predicate):
if predicate in self.combinations:
for k, v in combination:
if (
k in self.combinations[predicate]
and self.combinations[predicate][k] != v
):
return False
return True
def _do(self, cfg, fn, *args, **kw):
if len(args) > 1:
insp = inspect_getargspec(fn)
combination = list(zip(insp.args[1:], args[1:]))
else:
combination = None
for skip in self.skips:
if self._check_combinations(combination, skip) and skip(cfg):
msg = "'%s' : %s" % (
fixture_functions.get_current_test_name(),
skip._as_string(cfg),
)
config.skip_test(msg)
try:
return_value = fn(*args, **kw)
except Exception as ex:
self._expect_failure(cfg, ex, combination, name=fn.__name__)
else:
self._expect_success(cfg, combination, name=fn.__name__)
return return_value
def _expect_failure(self, config, ex, combination, name="block"):
for fail in self.fails:
if self._check_combinations(combination, fail) and fail(config):
if sqla_util.py2k:
str_ex = unicode(ex).encode( # noqa: F821
"utf-8", errors="ignore"
)
else:
str_ex = str(ex)
print(
(
"%s failed as expected (%s): %s "
% (name, fail._as_string(config), str_ex)
)
)
break
else:
util.raise_from_cause(ex)
def _expect_success(self, config, combination, name="block"):
if not self.fails:
return
for fail in self.fails:
if self._check_combinations(combination, fail) and fail(config):
raise AssertionError(
"Unexpected success for '%s' (%s)"
% (
name,
" and ".join(
fail._as_string(config) for fail in self.fails
),
)
)
def requires_tag(tagname):
return tags([tagname])
def tags(tagnames):
comp = compound()
comp.tags.update(tagnames)
return comp
def only_if(predicate, reason=None):
predicate = _as_predicate(predicate)
return skip_if(NotPredicate(predicate), reason)
def succeeds_if(predicate, reason=None):
predicate = _as_predicate(predicate)
return fails_if(NotPredicate(predicate), reason)
class Predicate(object):
@classmethod
def as_predicate(cls, predicate, description=None):
if isinstance(predicate, compound):
return cls.as_predicate(predicate.enabled_for_config, description)
elif isinstance(predicate, Predicate):
if description and predicate.description is None:
predicate.description = description
return predicate
elif isinstance(predicate, (list, set)):
return OrPredicate(
[cls.as_predicate(pred) for pred in predicate], description
)
elif isinstance(predicate, tuple):
return SpecPredicate(*predicate)
elif isinstance(predicate, sqla_util.string_types):
tokens = re.match(
r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate
)
if not tokens:
raise ValueError(
"Couldn't locate DB name in predicate: %r" % predicate
)
db = tokens.group(1)
op = tokens.group(2)
spec = (
tuple(int(d) for d in tokens.group(3).split("."))
if tokens.group(3)
else None
)
return SpecPredicate(db, op, spec, description=description)
elif callable(predicate):
return LambdaPredicate(predicate, description)
else:
assert False, "unknown predicate type: %s" % predicate
def _format_description(self, config, negate=False):
bool_ = self(config)
if negate:
bool_ = not negate
return self.description % {
"driver": config.db.url.get_driver_name()
if config
else "<no driver>",
"database": config.db.url.get_backend_name()
if config
else "<no database>",
"doesnt_support": "doesn't support" if bool_ else "does support",
"does_support": "does support" if bool_ else "doesn't support",
}
def _as_string(self, config=None, negate=False):
raise NotImplementedError()
class BooleanPredicate(Predicate):
def __init__(self, value, description=None):
self.value = value
self.description = description or "boolean %s" % value
def __call__(self, config):
return self.value
def _as_string(self, config, negate=False):
return self._format_description(config, negate=negate)
class SpecPredicate(Predicate):
def __init__(self, db, op=None, spec=None, description=None):
self.db = db
self.op = op
self.spec = spec
self.description = description
_ops = {
"<": operator.lt,
">": operator.gt,
"==": operator.eq,
"!=": operator.ne,
"<=": operator.le,
">=": operator.ge,
"in": operator.contains,
"between": lambda val, pair: val >= pair[0] and val <= pair[1],
}
def __call__(self, config):
engine = config.db
if "+" in self.db:
dialect, driver = self.db.split("+")
else:
dialect, driver = self.db, None
if dialect and engine.name != dialect:
return False
if driver is not None and engine.driver != driver:
return False
if self.op is not None:
assert driver is None, "DBAPI version specs not supported yet"
version = _server_version(engine)
oper = (
hasattr(self.op, "__call__") and self.op or self._ops[self.op]
)
return oper(version, self.spec)
else:
return True
def _as_string(self, config, negate=False):
if self.description is not None:
return self._format_description(config)
elif self.op is None:
if negate:
return "not %s" % self.db
else:
return "%s" % self.db
else:
if negate:
return "not %s %s %s" % (self.db, self.op, self.spec)
else:
return "%s %s %s" % (self.db, self.op, self.spec)
class LambdaPredicate(Predicate):
def __init__(self, lambda_, description=None, args=None, kw=None):
spec = inspect_getargspec(lambda_)
if not spec[0]:
self.lambda_ = lambda db: lambda_()
else:
self.lambda_ = lambda_
self.args = args or ()
self.kw = kw or {}
if description:
self.description = description
elif lambda_.__doc__:
self.description = lambda_.__doc__
else:
self.description = "custom function"
def __call__(self, config):
return self.lambda_(config)
def _as_string(self, config, negate=False):
return self._format_description(config)
class NotPredicate(Predicate):
def __init__(self, predicate, description=None):
self.predicate = predicate
self.description = description
def __call__(self, config):
return not self.predicate(config)
def _as_string(self, config, negate=False):
if self.description:
return self._format_description(config, not negate)
else:
return self.predicate._as_string(config, not negate)
class OrPredicate(Predicate):
def __init__(self, predicates, description=None):
self.predicates = predicates
self.description = description
def __call__(self, config):
for pred in self.predicates:
if pred(config):
return True
return False
def _eval_str(self, config, negate=False):
if negate:
conjunction = " and "
else:
conjunction = " or "
return conjunction.join(
p._as_string(config, negate=negate) for p in self.predicates
)
def _negation_str(self, config):
if self.description is not None:
return "Not " + self._format_description(config)
else:
return self._eval_str(config, negate=True)
def _as_string(self, config, negate=False):
if negate:
return self._negation_str(config)
else:
if self.description is not None:
return self._format_description(config)
else:
return self._eval_str(config)
_as_predicate = Predicate.as_predicate
def _is_excluded(db, op, spec):
return SpecPredicate(db, op, spec)(config._current)
def _server_version(engine):
"""Return a server_version_info tuple."""
# force metadata to be retrieved
conn = engine.connect()
version = getattr(engine.dialect, "server_version_info", None)
if version is None:
version = ()
conn.close()
return version
def db_spec(*dbs):
return OrPredicate([Predicate.as_predicate(db) for db in dbs])
def open(): # noqa
return skip_if(BooleanPredicate(False, "mark as execute"))
def closed():
return skip_if(BooleanPredicate(True, "marked as skip"))
def fails(reason=None):
return fails_if(BooleanPredicate(True, reason or "expected to fail"))
@decorator
def future(fn, *arg):
return fails_if(LambdaPredicate(fn), "Future feature")
def fails_on(db, reason=None):
return fails_if(db, reason)
def fails_on_everything_except(*dbs):
return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs]))
def skip(db, reason=None):
return skip_if(db, reason)
def only_on(dbs, reason=None):
return only_if(
OrPredicate(
[Predicate.as_predicate(db, reason) for db in util.to_list(dbs)]
)
)
def exclude(db, op, spec, reason=None):
return skip_if(SpecPredicate(db, op, spec), reason)
def against(config, *queries):
assert queries, "no queries sent!"
return OrPredicate([Predicate.as_predicate(query) for query in queries])(
config
)