You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

111 lines
3.2 KiB
Python

import numpy as np
from . import pypocketfft as pfft
from .helper import (_asfarray, _init_nd_shape_and_axes, _datacopied,
_fix_shape, _fix_shape_1d, _normalization, _workers)
import functools
def _r2r(forward, transform, x, type=2, n=None, axis=-1, norm=None,
         overwrite_x=False, workers=None):
    """Apply a 1-D real-to-real DCT or DST along a single axis.

    Parameters
    ----------
    forward : bool
        ``True`` for the forward transform, ``False`` for its inverse. The
        direction is passed to ``_normalization`` and, for the inverse, a
        type-2 request is executed as type 3 and vice versa.
    transform : {pypocketfft.dct, pypocketfft.dst}
        The pocketfft kernel to invoke.
    x : array_like
        Input data. Complex input is handled by running ``transform``
        separately on the real and imaginary parts.
    type : int, optional
        Transform type handed to ``transform`` (after the inverse swap).
    n : int, optional
        If given, the input is adjusted to length ``n`` along ``axis`` by
        ``_fix_shape_1d`` before transforming.
    axis : int, optional
        Axis along which to transform.
    norm : str, optional
        Normalisation mode, resolved by ``_normalization``.
    overwrite_x : bool, optional
        If true, the converted input buffer is reused as the output buffer.
    workers : int, optional
        Parallelism setting, resolved by ``_workers``.

    Returns
    -------
    ndarray
        The result of ``transform``; for complex input, the array whose
        real and imaginary parts were filled by the two sub-transforms.

    Raises
    ------
    ValueError
        If ``n`` is not given and the input has no points along ``axis``.
    """
    buf = _asfarray(x)
    # A buffer that is already a private copy of ``x`` is safe to clobber.
    may_clobber = overwrite_x or _datacopied(buf, x)
    norm = _normalization(norm, forward)
    workers = _workers(workers)

    # Inverting a type-2 transform means running type 3, and vice versa.
    if not forward and type in (2, 3):
        type = 3 if type == 2 else 2

    if n is None:
        if buf.shape[axis] < 1:
            raise ValueError("invalid number of data points ({0}) specified"
                             .format(buf.shape[axis]))
    else:
        buf, was_copied = _fix_shape_1d(buf, n, axis)
        may_clobber = may_clobber or was_copied

    axes = (axis,)
    dest = buf if may_clobber else None

    if not np.iscomplexobj(x):
        return transform(buf, type, axes, norm, dest, workers)

    # Complex data: run the real-valued kernel on each component plane.
    if dest is None:
        dest = np.empty_like(buf)
    for src_part, dst_part in ((buf.real, dest.real), (buf.imag, dest.imag)):
        transform(src_part, type, axes, norm, dst_part, workers)
    return dest
# 1-D entry points: each binds the direction flag and the pocketfft kernel
# onto ``_r2r``. functools.partial objects carry no ``__name__`` of their
# own, so one is attached to each for code that introspects it.
dct = functools.partial(_r2r, True, pfft.dct)
idct = functools.partial(_r2r, False, pfft.dct)
dst = functools.partial(_r2r, True, pfft.dst)
idst = functools.partial(_r2r, False, pfft.dst)
for _func, _label in ((dct, 'dct'), (idct, 'idct'),
                      (dst, 'dst'), (idst, 'idst')):
    _func.__name__ = _label
del _func, _label
def _r2rn(forward, transform, x, type=2, s=None, axes=None, norm=None,
          overwrite_x=False, workers=None):
    """Apply an N-D real-to-real DCT or DST over one or more axes.

    Parameters
    ----------
    forward : bool
        ``True`` for the forward transform, ``False`` for its inverse. The
        direction is passed to ``_normalization`` and, for the inverse, a
        type-2 request is executed as type 3 and vice versa.
    transform : {pypocketfft.dct, pypocketfft.dst}
        The pocketfft kernel to invoke.
    x : array_like
        Input data. Complex input is handled by running ``transform``
        separately on the real and imaginary parts.
    type : int, optional
        Transform type handed to ``transform`` (after the inverse swap).
    s : sequence of int, optional
        Requested shape; resolved together with ``axes`` by
        ``_init_nd_shape_and_axes`` and applied by ``_fix_shape``.
    axes : sequence of int, optional
        Axes to transform; resolved by ``_init_nd_shape_and_axes``.
    norm : str, optional
        Normalisation mode, resolved by ``_normalization``.
    overwrite_x : bool, optional
        If true, the converted input buffer is reused as the output buffer.
    workers : int, optional
        Parallelism setting, resolved by ``_workers``.

    Returns
    -------
    ndarray
        The result of ``transform``; for complex input, the array whose
        real and imaginary parts were filled by the two sub-transforms.
        If no axes remain after resolution, ``x`` itself is returned
        unchanged (not converted to an array).
    """
    buf = _asfarray(x)
    shape, axes = _init_nd_shape_and_axes(buf, s, axes)
    # A buffer that is already a private copy of ``x`` is safe to clobber.
    may_clobber = overwrite_x or _datacopied(buf, x)

    # Nothing to transform: hand back the caller's object as-is.
    if len(axes) == 0:
        return x

    buf, was_copied = _fix_shape(buf, shape, axes)
    may_clobber = may_clobber or was_copied

    # Inverting a type-2 transform means running type 3, and vice versa.
    if not forward and type in (2, 3):
        type = 3 if type == 2 else 2

    norm = _normalization(norm, forward)
    workers = _workers(workers)
    dest = buf if may_clobber else None

    if not np.iscomplexobj(x):
        return transform(buf, type, axes, norm, dest, workers)

    # Complex data: run the real-valued kernel on each component plane.
    if dest is None:
        dest = np.empty_like(buf)
    for src_part, dst_part in ((buf.real, dest.real), (buf.imag, dest.imag)):
        transform(src_part, type, axes, norm, dst_part, workers)
    return dest
# N-D entry points: each binds the direction flag and the pocketfft kernel
# onto ``_r2rn``. functools.partial objects carry no ``__name__`` of their
# own, so one is attached to each for code that introspects it.
dctn = functools.partial(_r2rn, True, pfft.dct)
idctn = functools.partial(_r2rn, False, pfft.dct)
dstn = functools.partial(_r2rn, True, pfft.dst)
idstn = functools.partial(_r2rn, False, pfft.dst)
for _func, _label in ((dctn, 'dctn'), (idctn, 'idctn'),
                      (dstn, 'dstn'), (idstn, 'idstn')):
    _func.__name__ = _label
del _func, _label