You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

233 lines
7.4 KiB
Python

"""vendored pytestplugin functions from the most recent SQLAlchemy versions.
Alembic tests need to run on older versions of SQLAlchemy that don't
necessarily have all the latest testing fixtures.
"""
try:
# installed by bootstrap.py
import sqla_plugin_base as plugin_base
except ImportError:
# assume we're a package, use traditional import
from . import plugin_base
import inspect
import itertools
import operator
import os
import re
import sys
import pytest
from sqlalchemy.testing.plugin.pytestplugin import * # noqa
from sqlalchemy.testing.plugin.pytestplugin import pytest_configure as spc
# override selected SQLAlchemy pytest hooks with vendored functionality
def pytest_configure(config):
    """pytest hook: run SQLAlchemy's configure step, then register fixtures.

    SQLAlchemy's own ``pytest_configure`` (imported as ``spc``) is invoked
    first; afterwards ``PytestFixtureFunctions`` is installed as the fixture
    backend used by ``plugin_base``.
    """
    sqla_configure = spc
    sqla_configure(config)
    plugin_base.set_fixture_functions(PytestFixtureFunctions)
def pytest_pycollect_makeitem(collector, name, obj):
    """Decide how pytest collects ``obj``, found as ``name`` under ``collector``.

    Classes accepted by ``plugin_base.want_class`` are expanded through
    ``_parametrize_cls`` into one class collector per parametrized variant.
    Functions that live inside a test class and are accepted by
    ``plugin_base.want_method`` are handed back to pytest's default logic.
    Everything else is skipped.

    :return: a list of class collectors; ``None`` to fall back to pytest's
     default logic (which includes method-level parametrize); or an empty
     list, meaning "skip this item".
    """
    if inspect.isclass(obj) and plugin_base.want_class(name, obj):
        # pytest >= 5.4 wants nodes built through Node.from_parent() (direct
        # construction is an error from pytest 6 on); older pytest versions
        # don't have from_parent and take the constructor directly.
        ctor = getattr(pytest.Class, "from_parent", pytest.Class)
        return [
            ctor(name=parametrize_cls.__name__, parent=collector)
            for parametrize_cls in _parametrize_cls(collector.module, obj)
        ]
    elif (
        inspect.isfunction(obj)
        # Before pytest 7 methods are collected under a pytest.Instance node;
        # from pytest 7 on they are collected directly under pytest.Class and
        # pytest.Instance is deprecated/removed. Both node types expose the
        # owning class as ``.cls``, which is None for module-level collectors.
        and getattr(collector, "cls", None) is not None
        and plugin_base.want_method(collector.cls, obj)
    ):
        # None means, fall back to default logic, which includes
        # method-level parametrize
        return None
    else:
        # empty list means skip this item
        return []
_current_class = None
def _parametrize_cls(module, cls):
"""implement a class-based version of pytest parametrize."""
if "_sa_parametrize" not in cls.__dict__:
return [cls]
_sa_parametrize = cls._sa_parametrize
classes = []
for full_param_set in itertools.product(
*[params for argname, params in _sa_parametrize]
):
cls_variables = {}
for argname, param in zip(
[_sa_param[0] for _sa_param in _sa_parametrize], full_param_set
):
if not argname:
raise TypeError("need argnames for class-based combinations")
argname_split = re.split(r",\s*", argname)
for arg, val in zip(argname_split, param.values):
cls_variables[arg] = val
parametrized_name = "_".join(
# token is a string, but in py2k py.test is giving us a unicode,
# so call str() on it.
str(re.sub(r"\W", "", token))
for param in full_param_set
for token in param.id.split("-")
)
name = "%s_%s" % (cls.__name__, parametrized_name)
newcls = type.__new__(type, name, (cls,), cls_variables)
setattr(module, name, newcls)
classes.append(newcls)
return classes
def getargspec(fn):
    """Return the argument specification of ``fn``.

    On Python 3 this is ``inspect.getfullargspec``; on any other major
    version (intended for Python 2) the legacy ``inspect.getargspec`` is
    used.
    """
    if sys.version_info.major != 3:
        return inspect.getargspec(fn)
    return inspect.getfullargspec(fn)
class PytestFixtureFunctions(plugin_base.FixtureFunctions):
    """pytest-backed implementation of ``plugin_base.FixtureFunctions``.

    Provides the skip, parametrization and fixture facilities that this
    plugin registers with ``plugin_base.set_fixture_functions``.
    """

    def skip_test_exception(self, *arg, **kw):
        """Return (not raise) a pytest skip exception built from the args."""
        return pytest.skip.Exception(*arg, **kw)

    # Maps each character of a ``combinations(id_=...)`` template to the
    # function that renders the corresponding value into the test id:
    # "i" uses the value as-is, "r" repr(), "s" str(), "n" its __name__.
    _combination_id_fns = {
        "i": lambda obj: obj,
        "r": repr,
        "s": str,
        "n": operator.attrgetter("__name__"),
    }

    def combinations(self, *arg_sets, **kw):
        """facade for pytest.mark.parametrize.

        Automatically derives argument names from the callable which in our
        case is always a method on a class with positional arguments.

        ids for parameter sets are derived using an optional template.

        :param arg_sets: the parameter sets. Each is a tuple of values, or a
         single non-tuple value for a one-argument set. A single iterator
         may be passed instead and is expanded into a list. Values that are
         ``exclusions.compound`` instances are removed from the arguments
         and applied to the decorated test function instead.
        :param argnames: optional explicit argument names. When omitted they
         are taken from the decorated function's signature, skipping the
         first (``self``) argument.
        :param id_: optional template with one character per position in a
         parameter set. "i" means the value is used only for the id.
         "n"/"r"/"s" mean the value is passed to the test and rendered into
         the id via ``__name__``/``repr()``/``str()``. "a" means the value
         is passed to the test but left out of the id.
        :return: a decorator for a test function or a test class.
        """
        from alembic.testing import exclusions

        if sys.version_info.major == 3:
            if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"):
                arg_sets = list(arg_sets[0])
        else:
            if len(arg_sets) == 1 and hasattr(arg_sets[0], "next"):
                arg_sets = list(arg_sets[0])

        argnames = kw.pop("argnames", None)

        # (exclusion, remaining args) pairs gathered while the parameter
        # sets are built below; applied to the test function in decorate()
        exclusion_combinations = []

        def _filter_exclusions(args):
            result = []
            gathered_exclusions = []
            for a in args:
                if isinstance(a, exclusions.compound):
                    gathered_exclusions.append(a)
                else:
                    result.append(a)

            exclusion_combinations.extend(
                [(exclusion, result) for exclusion in gathered_exclusions]
            )
            return result

        id_ = kw.pop("id_", None)

        if id_:
            _combination_id_fns = self._combination_id_fns

            # Positions whose value is passed through to the test; "i"
            # positions only contribute to the id. A plain index list is
            # used because operator.itemgetter returns a bare value (not a
            # tuple) for a single index, which broke templates with no
            # pass-through positions.
            arg_indexes = [
                idx
                for idx, char in enumerate(id_)
                if char in ("n", "r", "s", "a")
            ]

            def _passed_args(arg):
                filtered = _filter_exclusions(arg)
                return [filtered[idx] for idx in arg_indexes]

            fns = [
                (operator.itemgetter(idx), _combination_id_fns[char])
                for idx, char in enumerate(id_)
                if char in _combination_id_fns
            ]

            arg_sets = [
                pytest.param(
                    *_passed_args(arg),
                    id="-".join(
                        comb_fn(getter(arg)) for getter, comb_fn in fns
                    )
                )
                for arg in [
                    (arg,) if not isinstance(arg, tuple) else arg
                    for arg in arg_sets
                ]
            ]
        else:
            # ensure using pytest.param so that even a 1-arg paramset
            # still needs to be a tuple. otherwise parametrize tries to
            # interpret a single arg differently than tuple arg
            arg_sets = [
                pytest.param(*_filter_exclusions(arg))
                for arg in [
                    (arg,) if not isinstance(arg, tuple) else arg
                    for arg in arg_sets
                ]
            ]

        def decorate(fn):
            if inspect.isclass(fn):
                # classes are expanded later by _parametrize_cls()
                if "_sa_parametrize" not in fn.__dict__:
                    fn._sa_parametrize = []
                fn._sa_parametrize.append((argnames, arg_sets))
                return fn
            else:
                if argnames is None:
                    _argnames = getargspec(fn).args[1:]
                else:
                    _argnames = argnames

                if exclusion_combinations:
                    for exclusion, combination in exclusion_combinations:
                        combination_by_kw = {
                            argname: val
                            for argname, val in zip(_argnames, combination)
                        }
                        exclusion = exclusion.with_combination(
                            **combination_by_kw
                        )
                        fn = exclusion(fn)
                return pytest.mark.parametrize(_argnames, arg_sets)(fn)

        return decorate

    def param_ident(self, *parameters):
        """Return a pytest.param whose id is the first of ``parameters``."""
        ident = parameters[0]
        return pytest.param(*parameters[1:], id=ident)

    def fixture(self, *arg, **kw):
        """Pass-through to ``pytest.fixture``."""
        return pytest.fixture(*arg, **kw)

    def get_current_test_name(self):
        """Return ``$PYTEST_CURRENT_TEST``, or None when it is not set."""
        return os.environ.get("PYTEST_CURRENT_TEST")