"""Vendored pytest plugin functions from the most recent SQLAlchemy versions.

Alembic tests need to run on older versions of SQLAlchemy that don't
necessarily have all the latest testing fixtures, so the hooks and fixture
helpers below override the ones star-imported from SQLAlchemy's own plugin.

"""
try:
    # installed by bootstrap.py
    import sqla_plugin_base as plugin_base
except ImportError:
    # assume we're a package, use traditional import
    from . import plugin_base

import inspect
import itertools
import operator
import os
import re
import sys

import pytest
from sqlalchemy.testing.plugin.pytestplugin import *  # noqa
from sqlalchemy.testing.plugin.pytestplugin import pytest_configure as spc


# override selected SQLAlchemy pytest hooks with vendored functionality
def pytest_configure(config):
    """Run SQLAlchemy's ``pytest_configure``, then install our fixtures.

    Registers ``PytestFixtureFunctions`` with ``plugin_base`` so that the
    vendored ``combinations()`` etc. are the ones used by the test suite.
    """
    spc(config)

    plugin_base.set_fixture_functions(PytestFixtureFunctions)


def _class_collector(collector, name):
    """Build a ``pytest.Class`` collector for attribute *name* of *collector*.

    pytest 5.4 introduced ``Node.from_parent()`` and deprecated direct
    construction of collector nodes (newer releases reject it outright), so
    use ``from_parent`` when it exists and fall back to the constructor on
    older pytest versions.
    """
    from_parent = getattr(pytest.Class, "from_parent", None)
    if from_parent is not None:
        return from_parent(collector, name=name)
    return pytest.Class(name, parent=collector)


def pytest_pycollect_makeitem(collector, name, obj):
    """Collection hook: expand class-level combinations and filter tests.

    :return: a list of class collectors for a wanted test class (one per
     class-level parameter combination); ``None`` for a wanted test method,
     so pytest's default logic (including method-level parametrize)
     applies; an empty list to skip anything else.
    """
    if inspect.isclass(obj) and plugin_base.want_class(name, obj):
        return [
            _class_collector(collector, parametrize_cls.__name__)
            for parametrize_cls in _parametrize_cls(collector.module, obj)
        ]
    elif (
        inspect.isfunction(obj)
        # methods are collected by a class-level collector (pytest.Instance
        # before pytest 7, pytest.Class afterwards); both expose the owning
        # class as ``.cls`` while module-level collectors report None.
        # pytest.Instance itself no longer exists in recent pytest.
        and getattr(collector, "cls", None) is not None
        and plugin_base.want_method(collector.cls, obj)
    ):
        # None means, fall back to default logic, which includes
        # method-level parametrize
        return None
    else:
        # empty list means skip this item
        return []


# not referenced within this module; kept for compatibility
_current_class = None


def _split_argnames(argnames):
    """Return *argnames* as a list of argument names.

    Accepts either a comma-separated string such as ``"x, y"`` (split and
    stripped like ``pytest.mark.parametrize`` does for its string form) or an
    existing list/tuple of names.
    """
    if isinstance(argnames, (list, tuple)):
        return list(argnames)
    return [part.strip() for part in argnames.split(",") if part.strip()]


def _parametrize_cls(module, cls):
    """Implement a class-based version of pytest parametrize.

    ``cls._sa_parametrize`` (appended to by
    ``PytestFixtureFunctions.combinations`` when it decorates a class) is a
    list of ``(argnames, param_sets)`` pairs.  For every element of the
    cartesian product of those parameter sets, a subclass of *cls* is created
    with the parameter values set as class attributes.  Each subclass is also
    set as an attribute of *module* so pytest can collect it by name.

    :param module: the module object the test class is collected from.
    :param cls: the test class.
    :return: ``[cls]`` when the class is not parametrized, otherwise the list
     of generated subclasses.
    :raises TypeError: if a class-level combination has no argnames.
    """
    if "_sa_parametrize" not in cls.__dict__:
        return [cls]

    _sa_parametrize = cls._sa_parametrize
    classes = []
    for full_param_set in itertools.product(
        *[params for argname, params in _sa_parametrize]
    ):
        cls_variables = {}

        for (argname, _params), param in zip(_sa_parametrize, full_param_set):
            if not argname:
                raise TypeError("need argnames for class-based combinations")
            for arg, val in zip(_split_argnames(argname), param.values):
                cls_variables[arg] = val
        parametrized_name = "_".join(
            # token is a string, but in py2k py.test is giving us a unicode,
            # so call str() on it.
            str(re.sub(r"\W", "", token))
            for param in full_param_set
            for token in param.id.split("-")
        )
        name = "%s_%s" % (cls.__name__, parametrized_name)
        newcls = type.__new__(type, name, (cls,), cls_variables)
        setattr(module, name, newcls)
        classes.append(newcls)
    return classes


def getargspec(fn):
    """Return the argument spec of *fn* on both Python 2 and Python 3.

    ``inspect.getargspec`` is deprecated on Python 3 (and removed in 3.11),
    so ``inspect.getfullargspec`` is used there instead.
    """
    if sys.version_info >= (3,):
        return inspect.getfullargspec(fn)
    else:
        return inspect.getargspec(fn)


class PytestFixtureFunctions(plugin_base.FixtureFunctions):
    """pytest implementation of the ``plugin_base`` fixture facade."""

    def skip_test_exception(self, *arg, **kw):
        """Return (not raise) a ``pytest.skip.Exception`` built from args."""
        return pytest.skip.Exception(*arg, **kw)

    # id template characters for combinations(id_=...), mapped to the
    # function that renders that column into the parameter-set id
    _combination_id_fns = {
        "i": lambda obj: obj,
        "r": repr,
        "s": str,
        "n": operator.attrgetter("__name__"),
    }

    def combinations(self, *arg_sets, **kw):
        """Facade for ``pytest.mark.parametrize``.

        Automatically derives argument names from the callable, which in our
        case is always a method on a class with positional arguments.

        ids for parameter sets are derived using an optional template passed
        as ``id_``: one character per positional element of each argument
        set.  ``i`` uses the value as (part of) the id without passing it to
        the test; ``r`` / ``s`` / ``n`` add ``repr()`` / ``str()`` /
        ``__name__`` of the value to the id and pass the value to the test;
        ``a`` passes the value to the test without adding it to the id.

        :param arg_sets: parameter sets; a non-tuple is treated as a
         1-tuple.  A single iterator argument is expanded into its items.
         Elements that are ``exclusions.compound`` instances are not passed
         to the test; they are applied to it for that combination instead.
        :param argnames: optional explicit argument names, either a
         comma-separated string or a list of names.  Required when
         decorating a class.
        :param id_: optional id template, see above.
        :return: a decorator for a test function or a test class.
        """
        from alembic.testing import exclusions

        if sys.version_info >= (3,):
            if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"):
                arg_sets = list(arg_sets[0])
        else:
            if len(arg_sets) == 1 and hasattr(arg_sets[0], "next"):
                arg_sets = list(arg_sets[0])

        argnames = kw.pop("argnames", None)
        id_ = kw.pop("id_", None)

        # (exclusion, argument values) pairs gathered while building the
        # parameter sets; applied to the test function in decorate()
        exclusion_combinations = []

        if id_:
            id_fns = self._combination_id_fns
            # positions (among the non-exclusion values) that are passed to
            # the test function
            arg_indexes = [
                idx
                for idx, char in enumerate(id_)
                if char in ("n", "r", "s", "a")
            ]
            # (getter, renderer) pairs that build the id from the raw row
            fns = [
                (operator.itemgetter(idx), id_fns[char])
                for idx, char in enumerate(id_)
                if char in id_fns
            ]
        else:
            arg_indexes = None
            fns = None

        def _filter_exclusions(args):
            """Return the values of *args* to pass to the test.

            Exclusion objects are removed and recorded, together with the
            exact values the test receives, in ``exclusion_combinations``.
            With an id template only the selected columns are kept, so the
            recorded values line up one-to-one with the argument names.
            """
            values = []
            gathered_exclusions = []
            for a in args:
                if isinstance(a, exclusions.compound):
                    gathered_exclusions.append(a)
                else:
                    values.append(a)

            if arg_indexes is not None:
                values = [values[idx] for idx in arg_indexes]

            exclusion_combinations.extend(
                [(exclusion, values) for exclusion in gathered_exclusions]
            )
            return values

        # ensure using pytest.param so that even a 1-arg paramset
        # still needs to be a tuple.  otherwise parametrize tries to
        # interpret a single arg differently than tuple arg
        rows = [
            (arg,) if not isinstance(arg, tuple) else arg for arg in arg_sets
        ]
        if id_:
            arg_sets = [
                pytest.param(
                    *_filter_exclusions(row),
                    id="-".join(comb_fn(getter(row)) for getter, comb_fn in fns)
                )
                for row in rows
            ]
        else:
            arg_sets = [pytest.param(*_filter_exclusions(row)) for row in rows]

        def decorate(fn):
            if inspect.isclass(fn):
                # class-level combinations are expanded at collection time
                # by _parametrize_cls()
                if "_sa_parametrize" not in fn.__dict__:
                    fn._sa_parametrize = []
                fn._sa_parametrize.append((argnames, arg_sets))
                return fn
            else:
                if argnames is None:
                    # drop the leading "self" of the test method
                    _argnames = getargspec(fn).args[1:]
                else:
                    # split "x, y" into names so they zip one-to-one with
                    # the argument values below (zipping the raw string
                    # would pair values with single characters)
                    _argnames = _split_argnames(argnames)

                for exclusion, combination in exclusion_combinations:
                    combination_by_kw = dict(zip(_argnames, combination))
                    exclusion = exclusion.with_combination(
                        **combination_by_kw
                    )
                    fn = exclusion(fn)
                return pytest.mark.parametrize(_argnames, arg_sets)(fn)

        return decorate

    def param_ident(self, *parameters):
        """Return ``pytest.param(*parameters[1:])`` with id parameters[0]."""
        ident = parameters[0]
        return pytest.param(*parameters[1:], id=ident)

    def fixture(self, *arg, **kw):
        """Pass-through to ``pytest.fixture``."""
        return pytest.fixture(*arg, **kw)

    def get_current_test_name(self):
        """Return ``$PYTEST_CURRENT_TEST``, or None when it is not set."""
        return os.environ.get("PYTEST_CURRENT_TEST")