You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

65 lines
1.8 KiB
Python

"""Test how the ufuncs in special handle nan inputs.
"""
from typing import Callable, Dict
import numpy as np
from numpy.testing import assert_array_equal, assert_, suppress_warnings
import pytest
import scipy.special as sc
# Maps a ufunc to the reason it is expected to fail. _get_ufuncs() looks
# entries up by the ufunc object itself and uses the value as the xfail
# reason, so keys are ufuncs and values are strings.
KNOWNFAILURES: Dict[np.ufunc, str] = {}
# Maps a ufunc to a callable that test_nan_inputs applies to the unpacked
# result (``POSTPROCESSING[func](*res)``) before checking it for nan.
POSTPROCESSING: Dict[np.ufunc, Callable] = {}
def _get_ufuncs():
    """Collect the ufuncs exposed by ``scipy.special`` for parametrization.

    Walks the attributes of ``scipy.special`` in sorted name order and keeps
    those that are ``np.ufunc`` instances. A ufunc with an entry in
    ``KNOWNFAILURES`` is wrapped in a ``pytest.param`` carrying an
    ``xfail(run=False)`` mark whose reason is that entry.

    Returns
    -------
    ufuncs : list
        The ufuncs (or ``pytest.param`` wrappers around them), ordered by
        attribute name.
    ufunc_names : list of str
        The attribute names, in the same order as ``ufuncs``.
    """
    ufuncs = []
    ufunc_names = []
    namespace = vars(sc)
    for name in sorted(namespace):
        obj = namespace[name]
        if not isinstance(obj, np.ufunc):
            continue
        msg = KNOWNFAILURES.get(obj)
        if msg is not None:
            # Known failure: keep it in the list, but marked so that pytest
            # reports it as xfail without running it.
            fail = pytest.mark.xfail(run=False, reason=msg)
            obj = pytest.param(obj, marks=fail)
        ufuncs.append(obj)
        ufunc_names.append(name)
    return ufuncs, ufunc_names
# Built once at module level; test_nan_inputs is parametrized over UFUNCS
# and uses UFUNC_NAMES as the test ids.
UFUNCS, UFUNC_NAMES = _get_ufuncs()
@pytest.mark.parametrize("func", UFUNCS, ids=UFUNC_NAMES)
def test_nan_inputs(func):
    """Check that ``func`` returns nan when every input is nan.

    The ufunc is called with ``func.nin`` nan arguments. If the call raises
    ``TypeError`` the test returns without asserting anything. Otherwise the
    result (after the optional ``POSTPROCESSING`` hook for ``func``) must be
    nan everywhere.
    """
    args = (np.nan,) * func.nin
    with suppress_warnings() as cast_sup:
        # Ignore warnings about unsafe casts from legacy wrappers
        cast_sup.filter(RuntimeWarning,
                        "floating point number truncated to an integer")
        try:
            # Separate name so the inner context does not shadow the outer one.
            with suppress_warnings() as deprecation_sup:
                deprecation_sup.filter(DeprecationWarning)
                res = func(*args)
        except TypeError:
            # One of the arguments doesn't take real inputs
            return
    if func in POSTPROCESSING:
        # The hook receives the result unpacked as positional arguments.
        res = POSTPROCESSING[func](*res)

    msg = "got {} instead of nan".format(res)
    assert_array_equal(np.isnan(res), True, err_msg=msg)
def test_legacy_cast():
    """Check that ``sc.bdtrc(np.nan, 1, 0.5)`` evaluates to nan.

    The "floating point number truncated to an integer" ``RuntimeWarning``
    is filtered out while the call is made.
    """
    with suppress_warnings() as sup:
        sup.filter(RuntimeWarning,
                   "floating point number truncated to an integer")
        res = sc.bdtrc(np.nan, 1, 0.5)
        assert_(np.isnan(res))