You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

65 lines
1.8 KiB
Python

3 years ago
"""Test how the ufuncs in special handle nan inputs.
"""
from typing import Callable, Dict
import numpy as np
from numpy.testing import assert_array_equal, assert_, suppress_warnings
import pytest
import scipy.special as sc
KNOWNFAILURES: Dict[str, Callable] = {}
POSTPROCESSING: Dict[str, Callable] = {}
def _get_ufuncs():
ufuncs = []
ufunc_names = []
for name in sorted(sc.__dict__):
obj = sc.__dict__[name]
if not isinstance(obj, np.ufunc):
continue
msg = KNOWNFAILURES.get(obj)
if msg is None:
ufuncs.append(obj)
ufunc_names.append(name)
else:
fail = pytest.mark.xfail(run=False, reason=msg)
ufuncs.append(pytest.param(obj, marks=fail))
ufunc_names.append(name)
return ufuncs, ufunc_names
UFUNCS, UFUNC_NAMES = _get_ufuncs()
@pytest.mark.parametrize("func", UFUNCS, ids=UFUNC_NAMES)
def test_nan_inputs(func):
args = (np.nan,)*func.nin
with suppress_warnings() as sup:
# Ignore warnings about unsafe casts from legacy wrappers
sup.filter(RuntimeWarning,
"floating point number truncated to an integer")
try:
with suppress_warnings() as sup:
sup.filter(DeprecationWarning)
res = func(*args)
except TypeError:
# One of the arguments doesn't take real inputs
return
if func in POSTPROCESSING:
res = POSTPROCESSING[func](*res)
msg = "got {} instead of nan".format(res)
assert_array_equal(np.isnan(res), True, err_msg=msg)
def test_legacy_cast():
with suppress_warnings() as sup:
sup.filter(RuntimeWarning,
"floating point number truncated to an integer")
res = sc.bdtrc(np.nan, 1, 0.5)
assert_(np.isnan(res))