You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
488 lines
16 KiB
Python
488 lines
16 KiB
Python
from os.path import join, dirname
|
|
|
|
import numpy as np
|
|
from numpy.testing import (
|
|
assert_array_almost_equal, assert_equal, assert_allclose)
|
|
import pytest
|
|
from pytest import raises as assert_raises
|
|
|
|
from scipy.fft._pocketfft.realtransforms import (
|
|
dct, idct, dst, idst, dctn, idctn, dstn, idstn)
|
|
|
|
fftpack_test_dir = join(dirname(__file__), '..', '..', '..', 'fftpack', 'tests')
|
|
|
|
MDATA_COUNT = 8
|
|
FFTWDATA_COUNT = 14
|
|
|
|
def is_longdouble_binary_compatible():
|
|
try:
|
|
one = np.frombuffer(
|
|
b'\x00\x00\x00\x00\x00\x00\x00\x80\xff\x3f\x00\x00\x00\x00\x00\x00',
|
|
dtype='<f16')
|
|
return one == np.longfloat(1.)
|
|
except TypeError:
|
|
return False
|
|
|
|
|
|
def get_reference_data():
|
|
ref = getattr(globals(), '__reference_data', None)
|
|
if ref is not None:
|
|
return ref
|
|
|
|
# Matlab reference data
|
|
MDATA = np.load(join(fftpack_test_dir, 'test.npz'))
|
|
X = [MDATA['x%d' % i] for i in range(MDATA_COUNT)]
|
|
Y = [MDATA['y%d' % i] for i in range(MDATA_COUNT)]
|
|
|
|
# FFTW reference data: the data are organized as follows:
|
|
# * SIZES is an array containing all available sizes
|
|
# * for every type (1, 2, 3, 4) and every size, the array dct_type_size
|
|
# contains the output of the DCT applied to the input np.linspace(0, size-1,
|
|
# size)
|
|
FFTWDATA_DOUBLE = np.load(join(fftpack_test_dir, 'fftw_double_ref.npz'))
|
|
FFTWDATA_SINGLE = np.load(join(fftpack_test_dir, 'fftw_single_ref.npz'))
|
|
FFTWDATA_SIZES = FFTWDATA_DOUBLE['sizes']
|
|
assert len(FFTWDATA_SIZES) == FFTWDATA_COUNT
|
|
|
|
if is_longdouble_binary_compatible():
|
|
FFTWDATA_LONGDOUBLE = np.load(
|
|
join(fftpack_test_dir, 'fftw_longdouble_ref.npz'))
|
|
else:
|
|
FFTWDATA_LONGDOUBLE = {k: v.astype(np.longfloat)
|
|
for k,v in FFTWDATA_DOUBLE.items()}
|
|
|
|
ref = {
|
|
'FFTWDATA_LONGDOUBLE': FFTWDATA_LONGDOUBLE,
|
|
'FFTWDATA_DOUBLE': FFTWDATA_DOUBLE,
|
|
'FFTWDATA_SINGLE': FFTWDATA_SINGLE,
|
|
'FFTWDATA_SIZES': FFTWDATA_SIZES,
|
|
'X': X,
|
|
'Y': Y
|
|
}
|
|
|
|
globals()['__reference_data'] = ref
|
|
return ref
|
|
|
|
|
|
@pytest.fixture(params=range(FFTWDATA_COUNT))
|
|
def fftwdata_size(request):
|
|
return get_reference_data()['FFTWDATA_SIZES'][request.param]
|
|
|
|
@pytest.fixture(params=range(MDATA_COUNT))
|
|
def mdata_x(request):
|
|
return get_reference_data()['X'][request.param]
|
|
|
|
|
|
@pytest.fixture(params=range(MDATA_COUNT))
|
|
def mdata_xy(request):
|
|
ref = get_reference_data()
|
|
y = ref['Y'][request.param]
|
|
x = ref['X'][request.param]
|
|
return x, y
|
|
|
|
|
|
def fftw_dct_ref(type, size, dt):
|
|
x = np.linspace(0, size-1, size).astype(dt)
|
|
dt = np.result_type(np.float32, dt)
|
|
if dt == np.double:
|
|
data = get_reference_data()['FFTWDATA_DOUBLE']
|
|
elif dt == np.float32:
|
|
data = get_reference_data()['FFTWDATA_SINGLE']
|
|
elif dt == np.longfloat:
|
|
data = get_reference_data()['FFTWDATA_LONGDOUBLE']
|
|
else:
|
|
raise ValueError()
|
|
y = (data['dct_%d_%d' % (type, size)]).astype(dt)
|
|
return x, y, dt
|
|
|
|
|
|
def fftw_dst_ref(type, size, dt):
|
|
x = np.linspace(0, size-1, size).astype(dt)
|
|
dt = np.result_type(np.float32, dt)
|
|
if dt == np.double:
|
|
data = get_reference_data()['FFTWDATA_DOUBLE']
|
|
elif dt == np.float32:
|
|
data = get_reference_data()['FFTWDATA_SINGLE']
|
|
elif dt == np.longfloat:
|
|
data = get_reference_data()['FFTWDATA_LONGDOUBLE']
|
|
else:
|
|
raise ValueError()
|
|
y = (data['dst_%d_%d' % (type, size)]).astype(dt)
|
|
return x, y, dt
|
|
|
|
|
|
def ref_2d(func, x, **kwargs):
|
|
"""Calculate 2-D reference data from a 1d transform"""
|
|
x = np.array(x, copy=True)
|
|
for row in range(x.shape[0]):
|
|
x[row, :] = func(x[row, :], **kwargs)
|
|
for col in range(x.shape[1]):
|
|
x[:, col] = func(x[:, col], **kwargs)
|
|
return x
|
|
|
|
|
|
def naive_dct1(x, norm=None):
|
|
"""Calculate textbook definition version of DCT-I."""
|
|
x = np.array(x, copy=True)
|
|
N = len(x)
|
|
M = N-1
|
|
y = np.zeros(N)
|
|
m0, m = 1, 2
|
|
if norm == 'ortho':
|
|
m0 = np.sqrt(1.0/M)
|
|
m = np.sqrt(2.0/M)
|
|
for k in range(N):
|
|
for n in range(1, N-1):
|
|
y[k] += m*x[n]*np.cos(np.pi*n*k/M)
|
|
y[k] += m0 * x[0]
|
|
y[k] += m0 * x[N-1] * (1 if k % 2 == 0 else -1)
|
|
if norm == 'ortho':
|
|
y[0] *= 1/np.sqrt(2)
|
|
y[N-1] *= 1/np.sqrt(2)
|
|
return y
|
|
|
|
|
|
def naive_dst1(x, norm=None):
|
|
"""Calculate textbook definition version of DST-I."""
|
|
x = np.array(x, copy=True)
|
|
N = len(x)
|
|
M = N+1
|
|
y = np.zeros(N)
|
|
for k in range(N):
|
|
for n in range(N):
|
|
y[k] += 2*x[n]*np.sin(np.pi*(n+1.0)*(k+1.0)/M)
|
|
if norm == 'ortho':
|
|
y *= np.sqrt(0.5/M)
|
|
return y
|
|
|
|
|
|
def naive_dct4(x, norm=None):
|
|
"""Calculate textbook definition version of DCT-IV."""
|
|
x = np.array(x, copy=True)
|
|
N = len(x)
|
|
y = np.zeros(N)
|
|
for k in range(N):
|
|
for n in range(N):
|
|
y[k] += x[n]*np.cos(np.pi*(n+0.5)*(k+0.5)/(N))
|
|
if norm == 'ortho':
|
|
y *= np.sqrt(2.0/N)
|
|
else:
|
|
y *= 2
|
|
return y
|
|
|
|
|
|
def naive_dst4(x, norm=None):
|
|
"""Calculate textbook definition version of DST-IV."""
|
|
x = np.array(x, copy=True)
|
|
N = len(x)
|
|
y = np.zeros(N)
|
|
for k in range(N):
|
|
for n in range(N):
|
|
y[k] += x[n]*np.sin(np.pi*(n+0.5)*(k+0.5)/(N))
|
|
if norm == 'ortho':
|
|
y *= np.sqrt(2.0/N)
|
|
else:
|
|
y *= 2
|
|
return y
|
|
|
|
|
|
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128, np.longcomplex])
|
|
@pytest.mark.parametrize('transform', [dct, dst, idct, idst])
|
|
def test_complex(transform, dtype):
|
|
y = transform(1j*np.arange(5, dtype=dtype))
|
|
x = 1j*transform(np.arange(5))
|
|
assert_array_almost_equal(x, y)
|
|
|
|
|
|
# map (tranform, dtype, type) -> decimal
|
|
dec_map = {
|
|
# DCT
|
|
(dct, np.double, 1): 13,
|
|
(dct, np.float32, 1): 6,
|
|
|
|
(dct, np.double, 2): 14,
|
|
(dct, np.float32, 2): 5,
|
|
|
|
(dct, np.double, 3): 14,
|
|
(dct, np.float32, 3): 5,
|
|
|
|
(dct, np.double, 4): 13,
|
|
(dct, np.float32, 4): 6,
|
|
|
|
# IDCT
|
|
(idct, np.double, 1): 14,
|
|
(idct, np.float32, 1): 6,
|
|
|
|
(idct, np.double, 2): 14,
|
|
(idct, np.float32, 2): 5,
|
|
|
|
(idct, np.double, 3): 14,
|
|
(idct, np.float32, 3): 5,
|
|
|
|
(idct, np.double, 4): 14,
|
|
(idct, np.float32, 4): 6,
|
|
|
|
# DST
|
|
(dst, np.double, 1): 13,
|
|
(dst, np.float32, 1): 6,
|
|
|
|
(dst, np.double, 2): 14,
|
|
(dst, np.float32, 2): 6,
|
|
|
|
(dst, np.double, 3): 14,
|
|
(dst, np.float32, 3): 7,
|
|
|
|
(dst, np.double, 4): 13,
|
|
(dst, np.float32, 4): 6,
|
|
|
|
# IDST
|
|
(idst, np.double, 1): 14,
|
|
(idst, np.float32, 1): 6,
|
|
|
|
(idst, np.double, 2): 14,
|
|
(idst, np.float32, 2): 6,
|
|
|
|
(idst, np.double, 3): 14,
|
|
(idst, np.float32, 3): 6,
|
|
|
|
(idst, np.double, 4): 14,
|
|
(idst, np.float32, 4): 6,
|
|
}
|
|
|
|
for k,v in dec_map.copy().items():
|
|
if k[1] == np.double:
|
|
dec_map[(k[0], np.longdouble, k[2])] = v
|
|
elif k[1] == np.float32:
|
|
dec_map[(k[0], int, k[2])] = v
|
|
|
|
|
|
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
|
|
@pytest.mark.parametrize('type', [1, 2, 3, 4])
|
|
class TestDCT:
|
|
def test_definition(self, rdt, type, fftwdata_size):
|
|
x, yr, dt = fftw_dct_ref(type, fftwdata_size, rdt)
|
|
y = dct(x, type=type)
|
|
assert_equal(y.dtype, dt)
|
|
dec = dec_map[(dct, rdt, type)]
|
|
assert_allclose(y, yr, rtol=0., atol=np.max(yr)*10**(-dec))
|
|
|
|
@pytest.mark.parametrize('size', [7, 8, 9, 16, 32, 64])
|
|
def test_axis(self, rdt, type, size):
|
|
nt = 2
|
|
dec = dec_map[(dct, rdt, type)]
|
|
x = np.random.randn(nt, size)
|
|
y = dct(x, type=type)
|
|
for j in range(nt):
|
|
assert_array_almost_equal(y[j], dct(x[j], type=type),
|
|
decimal=dec)
|
|
|
|
x = x.T
|
|
y = dct(x, axis=0, type=type)
|
|
for j in range(nt):
|
|
assert_array_almost_equal(y[:,j], dct(x[:,j], type=type),
|
|
decimal=dec)
|
|
|
|
|
|
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
|
|
def test_dct1_definition_ortho(rdt, mdata_x):
|
|
# Test orthornomal mode.
|
|
dec = dec_map[(dct, rdt, 1)]
|
|
x = np.array(mdata_x, dtype=rdt)
|
|
dt = np.result_type(np.float32, rdt)
|
|
y = dct(x, norm='ortho', type=1)
|
|
y2 = naive_dct1(x, norm='ortho')
|
|
assert_equal(y.dtype, dt)
|
|
assert_allclose(y, y2, rtol=0., atol=np.max(y2)*10**(-dec))
|
|
|
|
|
|
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
|
|
def test_dct2_definition_matlab(mdata_xy, rdt):
|
|
# Test correspondence with matlab (orthornomal mode).
|
|
dt = np.result_type(np.float32, rdt)
|
|
x = np.array(mdata_xy[0], dtype=dt)
|
|
|
|
yr = mdata_xy[1]
|
|
y = dct(x, norm="ortho", type=2)
|
|
dec = dec_map[(dct, rdt, 2)]
|
|
assert_equal(y.dtype, dt)
|
|
assert_array_almost_equal(y, yr, decimal=dec)
|
|
|
|
|
|
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
|
|
def test_dct3_definition_ortho(mdata_x, rdt):
|
|
# Test orthornomal mode.
|
|
x = np.array(mdata_x, dtype=rdt)
|
|
dt = np.result_type(np.float32, rdt)
|
|
y = dct(x, norm='ortho', type=2)
|
|
xi = dct(y, norm="ortho", type=3)
|
|
dec = dec_map[(dct, rdt, 3)]
|
|
assert_equal(xi.dtype, dt)
|
|
assert_array_almost_equal(xi, x, decimal=dec)
|
|
|
|
|
|
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
|
|
def test_dct4_definition_ortho(mdata_x, rdt):
|
|
# Test orthornomal mode.
|
|
x = np.array(mdata_x, dtype=rdt)
|
|
dt = np.result_type(np.float32, rdt)
|
|
y = dct(x, norm='ortho', type=4)
|
|
y2 = naive_dct4(x, norm='ortho')
|
|
dec = dec_map[(dct, rdt, 4)]
|
|
assert_equal(y.dtype, dt)
|
|
assert_allclose(y, y2, rtol=0., atol=np.max(y2)*10**(-dec))
|
|
|
|
|
|
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
|
|
@pytest.mark.parametrize('type', [1, 2, 3, 4])
|
|
def test_idct_definition(fftwdata_size, rdt, type):
|
|
xr, yr, dt = fftw_dct_ref(type, fftwdata_size, rdt)
|
|
x = idct(yr, type=type)
|
|
dec = dec_map[(idct, rdt, type)]
|
|
assert_equal(x.dtype, dt)
|
|
assert_allclose(x, xr, rtol=0., atol=np.max(xr)*10**(-dec))
|
|
|
|
|
|
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
|
|
@pytest.mark.parametrize('type', [1, 2, 3, 4])
|
|
def test_definition(fftwdata_size, rdt, type):
|
|
xr, yr, dt = fftw_dst_ref(type, fftwdata_size, rdt)
|
|
y = dst(xr, type=type)
|
|
dec = dec_map[(dst, rdt, type)]
|
|
assert_equal(y.dtype, dt)
|
|
assert_allclose(y, yr, rtol=0., atol=np.max(yr)*10**(-dec))
|
|
|
|
|
|
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
|
|
def test_dst1_definition_ortho(rdt, mdata_x):
|
|
# Test orthornomal mode.
|
|
dec = dec_map[(dst, rdt, 1)]
|
|
x = np.array(mdata_x, dtype=rdt)
|
|
dt = np.result_type(np.float32, rdt)
|
|
y = dst(x, norm='ortho', type=1)
|
|
y2 = naive_dst1(x, norm='ortho')
|
|
assert_equal(y.dtype, dt)
|
|
assert_allclose(y, y2, rtol=0., atol=np.max(y2)*10**(-dec))
|
|
|
|
|
|
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
|
|
def test_dst4_definition_ortho(rdt, mdata_x):
|
|
# Test orthornomal mode.
|
|
dec = dec_map[(dst, rdt, 4)]
|
|
x = np.array(mdata_x, dtype=rdt)
|
|
dt = np.result_type(np.float32, rdt)
|
|
y = dst(x, norm='ortho', type=4)
|
|
y2 = naive_dst4(x, norm='ortho')
|
|
assert_equal(y.dtype, dt)
|
|
assert_array_almost_equal(y, y2, decimal=dec)
|
|
|
|
|
|
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
|
|
@pytest.mark.parametrize('type', [1, 2, 3, 4])
|
|
def test_idst_definition(fftwdata_size, rdt, type):
|
|
xr, yr, dt = fftw_dst_ref(type, fftwdata_size, rdt)
|
|
x = idst(yr, type=type)
|
|
dec = dec_map[(idst, rdt, type)]
|
|
assert_equal(x.dtype, dt)
|
|
assert_allclose(x, xr, rtol=0., atol=np.max(xr)*10**(-dec))
|
|
|
|
|
|
@pytest.mark.parametrize('routine', [dct, dst, idct, idst])
|
|
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.longfloat])
|
|
@pytest.mark.parametrize('shape, axis', [
|
|
((16,), -1), ((16, 2), 0), ((2, 16), 1)
|
|
])
|
|
@pytest.mark.parametrize('type', [1, 2, 3, 4])
|
|
@pytest.mark.parametrize('overwrite_x', [True, False])
|
|
@pytest.mark.parametrize('norm', [None, 'ortho'])
|
|
def test_overwrite(routine, dtype, shape, axis, type, norm, overwrite_x):
|
|
# Check input overwrite behavior
|
|
np.random.seed(1234)
|
|
if np.issubdtype(dtype, np.complexfloating):
|
|
x = np.random.randn(*shape) + 1j*np.random.randn(*shape)
|
|
else:
|
|
x = np.random.randn(*shape)
|
|
x = x.astype(dtype)
|
|
x2 = x.copy()
|
|
routine(x2, type, None, axis, norm, overwrite_x=overwrite_x)
|
|
|
|
sig = "%s(%s%r, %r, axis=%r, overwrite_x=%r)" % (
|
|
routine.__name__, x.dtype, x.shape, None, axis, overwrite_x)
|
|
if not overwrite_x:
|
|
assert_equal(x2, x, err_msg="spurious overwrite in %s" % sig)
|
|
|
|
|
|
class Test_DCTN_IDCTN(object):
|
|
dec = 14
|
|
dct_type = [1, 2, 3, 4]
|
|
norms = [None, 'backward', 'ortho', 'forward']
|
|
rstate = np.random.RandomState(1234)
|
|
shape = (32, 16)
|
|
data = rstate.randn(*shape)
|
|
|
|
@pytest.mark.parametrize('fforward,finverse', [(dctn, idctn),
|
|
(dstn, idstn)])
|
|
@pytest.mark.parametrize('axes', [None,
|
|
1, (1,), [1],
|
|
0, (0,), [0],
|
|
(0, 1), [0, 1],
|
|
(-2, -1), [-2, -1]])
|
|
@pytest.mark.parametrize('dct_type', dct_type)
|
|
@pytest.mark.parametrize('norm', ['ortho'])
|
|
def test_axes_round_trip(self, fforward, finverse, axes, dct_type, norm):
|
|
tmp = fforward(self.data, type=dct_type, axes=axes, norm=norm)
|
|
tmp = finverse(tmp, type=dct_type, axes=axes, norm=norm)
|
|
assert_array_almost_equal(self.data, tmp, decimal=12)
|
|
|
|
@pytest.mark.parametrize('funcn,func', [(dctn, dct), (dstn, dst)])
|
|
@pytest.mark.parametrize('dct_type', dct_type)
|
|
@pytest.mark.parametrize('norm', norms)
|
|
def test_dctn_vs_2d_reference(self, funcn, func, dct_type, norm):
|
|
y1 = funcn(self.data, type=dct_type, axes=None, norm=norm)
|
|
y2 = ref_2d(func, self.data, type=dct_type, norm=norm)
|
|
assert_array_almost_equal(y1, y2, decimal=11)
|
|
|
|
@pytest.mark.parametrize('funcn,func', [(idctn, idct), (idstn, idst)])
|
|
@pytest.mark.parametrize('dct_type', dct_type)
|
|
@pytest.mark.parametrize('norm', norms)
|
|
def test_idctn_vs_2d_reference(self, funcn, func, dct_type, norm):
|
|
fdata = dctn(self.data, type=dct_type, norm=norm)
|
|
y1 = funcn(fdata, type=dct_type, norm=norm)
|
|
y2 = ref_2d(func, fdata, type=dct_type, norm=norm)
|
|
assert_array_almost_equal(y1, y2, decimal=11)
|
|
|
|
@pytest.mark.parametrize('fforward,finverse', [(dctn, idctn),
|
|
(dstn, idstn)])
|
|
def test_axes_and_shape(self, fforward, finverse):
|
|
with assert_raises(ValueError,
|
|
match="when given, axes and shape arguments"
|
|
" have to be of the same length"):
|
|
fforward(self.data, s=self.data.shape[0], axes=(0, 1))
|
|
|
|
with assert_raises(ValueError,
|
|
match="when given, axes and shape arguments"
|
|
" have to be of the same length"):
|
|
fforward(self.data, s=self.data.shape, axes=0)
|
|
|
|
@pytest.mark.parametrize('fforward', [dctn, dstn])
|
|
def test_shape(self, fforward):
|
|
tmp = fforward(self.data, s=(128, 128), axes=None)
|
|
assert_equal(tmp.shape, (128, 128))
|
|
|
|
@pytest.mark.parametrize('fforward,finverse', [(dctn, idctn),
|
|
(dstn, idstn)])
|
|
@pytest.mark.parametrize('axes', [1, (1,), [1],
|
|
0, (0,), [0]])
|
|
def test_shape_is_none_with_axes(self, fforward, finverse, axes):
|
|
tmp = fforward(self.data, s=None, axes=axes, norm='ortho')
|
|
tmp = finverse(tmp, s=None, axes=axes, norm='ortho')
|
|
assert_array_almost_equal(self.data, tmp, decimal=self.dec)
|
|
|
|
|
|
@pytest.mark.parametrize('func', [dct, dctn, idct, idctn,
|
|
dst, dstn, idst, idstn])
|
|
def test_swapped_byte_order(func):
|
|
rng = np.random.RandomState(1234)
|
|
x = rng.rand(10)
|
|
swapped_dt = x.dtype.newbyteorder('S')
|
|
assert_allclose(func(x.astype(swapped_dt)), func(x))
|