You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

482 lines
14 KiB
Python

import numpy as np
from scipy.linalg import eig
from scipy.special import comb
from scipy.signal import convolve
__all__ = ['daub', 'qmf', 'cascade', 'morlet', 'ricker', 'morlet2', 'cwt']
def daub(p):
    """
    Return the FIR low-pass filter coefficients of a Daubechies wavelet.

    ``p >= 1`` gives the order of the zero at ``f = 1/2``; the returned
    filter has ``2 * p`` coefficients.

    Parameters
    ----------
    p : int
        Order of the zero at ``f = 1/2``. Supported values are 1 to 34.

    Returns
    -------
    daub : ndarray
        The ``2 * p`` filter coefficients, normalized so that they sum to
        ``sqrt(2)``.

    Raises
    ------
    ValueError
        If `p` is smaller than 1, or if `p` is 35 or larger (the
        polynomial factorization used for ``p > 3`` is not considered
        reliable there).

    """
    sqrt = np.sqrt
    if p < 1:
        raise ValueError("p must be at least 1.")

    # Closed-form solutions for the three smallest orders.
    if p == 1:
        c = 1 / sqrt(2)
        return np.array([c, c])
    if p == 2:
        f = sqrt(2) / 8
        c = sqrt(3)
        return f * np.array([1 + c, 3 + c, 3 - c, 1 - c])
    if p == 3:
        tmp = 12 * sqrt(10)
        z1 = 1.5 + sqrt(15 + tmp) / 6 - 1j * (sqrt(15) + sqrt(tmp - 15)) / 6
        z1c = np.conj(z1)
        f = sqrt(2) / 8
        d0 = np.real((1 - z1) * (1 - z1c))
        a0 = np.real(z1 * z1c)
        a1 = 2 * np.real(z1)
        return f / d0 * np.array([a0, 3 * a0 - a1, 3 * a0 - 3 * a1 + 1,
                                  a0 - 3 * a1 + 3, 3 - a1, 1])

    # ``not p < 35`` (rather than ``p >= 35``) keeps the original behaviour
    # of rejecting values that compare false against everything (e.g. NaN).
    if not p < 35:
        raise ValueError("Polynomial factorization does not work "
                         "well for p too large.")

    # General case: construct the polynomial and factor it.
    P = [comb(p - 1 + k, k, exact=True) for k in range(p)][::-1]
    yj = np.roots(P)

    # For each root, compute the two z roots and select the one with |z| > 1,
    # accumulating the selected factors into the final polynomial.
    c = np.poly1d([1, 1])**p
    q = np.poly1d([1])
    for k in range(p - 1):
        yval = yj[k]
        part = 2 * sqrt(yval * (yval - 1))
        const = 1 - 2 * yval
        z1 = const + part
        if abs(z1) < 1:
            z1 = const - part
        q = q * [1, -z1]

    q = c * np.real(q)
    # Normalize the result so that the coefficients sum to sqrt(2).
    q = q / np.sum(q) * sqrt(2)
    return q.c[::-1]
def qmf(hk):
    """
    Return the high-pass quadrature mirror filter of a low-pass filter.

    The taps of `hk` are reversed and then multiplied by an alternating
    sign pattern ``+1, -1, +1, ...`` (starting with ``+1``).

    Parameters
    ----------
    hk : array_like
        Coefficients of the low-pass filter.

    Returns
    -------
    ndarray
        High-pass filter coefficients, the same length as `hk`.

    """
    # Even positions keep their sign, odd positions are negated.
    alternating = np.array([-1 if idx % 2 else 1 for idx in range(len(hk))])
    reversed_taps = hk[::-1]
    return reversed_taps * alternating
def cascade(hk, J=7):
    """
    Return (x, phi, psi) at dyadic points ``K/2**J`` from filter coefficients.

    Parameters
    ----------
    hk : array_like
        Coefficients of low-pass filter.
    J : int, optional
        Values will be computed at grid points ``K/2**J``. Default is 7.

    Returns
    -------
    x : ndarray
        The dyadic points ``K/2**J`` for ``K=0...N * (2**J)-1`` where
        ``len(hk) = len(gk) = N+1``.
    phi : ndarray
        The scaling function ``phi(x)`` at `x`:
        ``phi(x) = sum(hk * phi(2x-k))``, where k is from 0 to N.
    psi : ndarray
        The wavelet function ``psi(x)`` at `x`:
        ``psi(x) = sum(gk * phi(2x-k))``, where k is from 0 to N and
        ``gk = qmf(hk)`` is the high-pass filter derived from `hk`.

    Raises
    ------
    ValueError
        If ``J > 30 - log2(len(hk))`` or ``J < 1``.

    Notes
    -----
    The algorithm uses the vector cascade algorithm described by Strang and
    Nguyen in "Wavelets and Filter Banks". It builds a dictionary of values
    and slices for quick reuse. Then inserts vectors into final vector at the
    end.

    """
    N = len(hk) - 1

    if J > 30 - np.log2(N + 1):
        raise ValueError("Too many levels.")
    if J < 1:
        raise ValueError("Too few levels.")

    # Construct the matrices needed.
    nn, kk = np.ogrid[:N, :N]
    s2 = np.sqrt(2)
    # Append a zero so that ``take`` works: the clipped indices -1 and N + 1
    # both address this trailing zero.
    thk = np.r_[hk, 0]
    gk = qmf(hk)
    tgk = np.r_[gk, 0]

    indx1 = np.clip(2 * nn - kk, -1, N + 1)
    indx2 = np.clip(2 * nn - kk + 1, -1, N + 1)
    # m[0, i] is built from the low-pass taps, m[1, i] from the high-pass taps.
    m = np.empty((2, 2, N, N), 'd')
    m[0, 0] = np.take(thk, indx1, 0)
    m[0, 1] = np.take(thk, indx2, 0)
    m[1, 0] = np.take(tgk, indx1, 0)
    m[1, 1] = np.take(tgk, indx2, 0)
    m *= s2

    # Construct the grid of points.
    step = 1 << J
    x = np.arange(0, N * step, dtype=float) / step
    phi = np.zeros_like(x)
    psi = np.zeros_like(x)

    # Find phi at the integers: the eigenvector of m[0, 0] whose eigenvalue
    # is closest to 1.
    lam, v = eig(m[0, 0])
    ind = np.argmin(np.absolute(lam - 1))
    v = np.real(v[:, ind])
    # The scaling function must integrate to 1, so normalize the eigenvector
    # such that sum(v) == 1 (flipping its sign first if necessary).
    sm = np.sum(v)
    if sm < 0:
        v = -v
        sm = -sm

    # bitdic maps the binary representation of an evaluation point x < 1
    # (i.e. the digits of 0.xxxx) to the vector of phi values at that
    # fractional offset.
    bitdic = {'0': v / sm}
    bitdic['1'] = np.dot(m[0, 1], bitdic['0'])
    phi[::step] = bitdic['0']
    phi[(1 << (J - 1))::step] = bitdic['1']
    psi[::step] = np.dot(m[1, 0], bitdic['0'])
    psi[(1 << (J - 1))::step] = np.dot(m[1, 1], bitdic['0'])

    # Descend down the levels inserting more and more values into bitdic.
    # Values are written to their final location as soon as they are
    # computed and kept in the dictionary for reuse by the next level.
    prevkeys = ['1']
    for level in range(2, J + 1):
        newkeys = ['%d%s' % (xx, yy) for xx in [0, 1] for yy in prevkeys]
        fac = 1 << (J - level)
        for key in newkeys:
            # The key is a ``level``-digit binary string; its integer value
            # is the offset (in units of ``fac``) within each unit interval.
            num = int(key, 2)
            pastphi = bitdic[key[1:]]
            ii = int(key[0])
            temp = np.dot(m[0, ii], pastphi)
            bitdic[key] = temp
            phi[num * fac::step] = temp
            psi[num * fac::step] = np.dot(m[1, ii], pastphi)
        prevkeys = newkeys

    return x, phi, psi
def morlet(M, w=5.0, s=1.0, complete=True):
    """
    Complex Morlet wavelet.

    Parameters
    ----------
    M : int
        Length of the wavelet.
    w : float, optional
        Omega0. Default is 5
    s : float, optional
        Scaling factor, windowed from ``-s*2*pi`` to ``+s*2*pi``. Default is 1.
    complete : bool, optional
        Whether to use the complete or the standard version.

    Returns
    -------
    morlet : (M,) ndarray

    See Also
    --------
    morlet2 : Implementation of Morlet wavelet, compatible with `cwt`.
    scipy.signal.gausspulse

    Notes
    -----
    The standard version (``complete=False``)::

        pi**-0.25 * exp(1j*w*x) * exp(-0.5*(x**2))

    is the form commonly referred to simply as the Morlet wavelet. This
    simplified version can cause admissibility problems at low values of
    `w`.

    The complete version (``complete=True``)::

        pi**-0.25 * (exp(1j*w*x) - exp(-0.5*(w**2))) * exp(-0.5*(x**2))

    subtracts a correction term to improve admissibility. For `w` greater
    than 5, the correction term is negligible.

    The energy of the returned wavelet is not normalised according to `s`.

    The fundamental frequency of this wavelet in Hz is given
    by ``f = 2*s*w*r / M`` where `r` is the sampling rate.

    This function was created before `cwt` and is not compatible with it.

    """
    half_span = s * 2 * np.pi
    grid = np.linspace(-half_span, half_span, M)

    # Complex carrier, optionally with the admissibility correction removed.
    carrier = np.exp(1j * w * grid)
    if complete:
        carrier = carrier - np.exp(-0.5 * (w**2))

    # Gaussian envelope including the pi**(-1/4) normalisation.
    envelope = np.exp(-0.5 * (grid**2)) * np.pi**(-0.25)
    return carrier * envelope
def ricker(points, a):
    """
    Return a Ricker wavelet, also known as the "Mexican hat wavelet".

    It models the function:

        ``A * (1 - (x/a)**2) * exp(-0.5*(x/a)**2)``,

    where ``A = 2/(sqrt(3*a)*(pi**0.25))``.

    Parameters
    ----------
    points : int
        Number of points in `vector`.
        Will be centered around 0.
    a : scalar
        Width parameter of the wavelet.

    Returns
    -------
    vector : (N,) ndarray
        Array of length `points` in shape of ricker curve.

    Examples
    --------
    >>> from scipy import signal
    >>> import matplotlib.pyplot as plt

    >>> points = 100
    >>> a = 4.0
    >>> vec2 = signal.ricker(points, a)
    >>> print(len(vec2))
    100
    >>> plt.plot(vec2)
    >>> plt.show()

    """
    amplitude = 2 / (np.sqrt(3 * a) * (np.pi**0.25))
    width_sq = a**2

    # Sample positions 0..points-1 shifted so the grid is centred on zero.
    offsets = np.arange(0, points) - (points - 1.0) / 2
    offsets_sq = offsets**2

    hat = 1 - offsets_sq / width_sq
    bell = np.exp(-offsets_sq / (2 * width_sq))
    return amplitude * hat * bell
def morlet2(M, s, w=5):
    """
    Complex Morlet wavelet, designed to work with `cwt`.

    Returns a Morlet wavelet normalised according to `s`::

        exp(1j*w*x/s) * exp(-0.5*(x/s)**2) * pi**(-0.25) * sqrt(1/s)

    No admissibility correction term is applied.

    Parameters
    ----------
    M : int
        Length of the wavelet.
    s : float
        Width parameter of the wavelet.
    w : float, optional
        Omega0. Default is 5

    Returns
    -------
    morlet : (M,) ndarray

    See Also
    --------
    morlet : Implementation of Morlet wavelet, incompatible with `cwt`

    Notes
    -----
    .. versionadded:: 1.4.0

    This function was designed to work with `cwt`. Because `morlet2`
    returns an array of complex numbers, the `dtype` argument of `cwt`
    should be set to `complex128` for best results.

    Note the difference in implementation with `morlet`.
    The fundamental frequency of this wavelet in Hz is given by::

        f = w*fs / (2*s*np.pi)

    where ``fs`` is the sampling rate and `s` is the wavelet width parameter.
    Similarly we can get the wavelet width parameter at ``f``::

        s = w*fs / (2*f*np.pi)

    Examples
    --------
    >>> from scipy import signal
    >>> import matplotlib.pyplot as plt

    >>> M = 100
    >>> s = 4.0
    >>> w = 2.0
    >>> wavelet = signal.morlet2(M, s, w)
    >>> plt.plot(abs(wavelet))
    >>> plt.show()

    This example shows basic use of `morlet2` with `cwt` in time-frequency
    analysis:

    >>> from scipy import signal
    >>> import matplotlib.pyplot as plt
    >>> t, dt = np.linspace(0, 1, 200, retstep=True)
    >>> fs = 1/dt
    >>> w = 6.
    >>> sig = np.cos(2*np.pi*(50 + 10*t)*t) + np.sin(40*np.pi*t)
    >>> freq = np.linspace(1, fs/2, 100)
    >>> widths = w*fs / (2*freq*np.pi)
    >>> cwtm = signal.cwt(sig, signal.morlet2, widths, w=w)
    >>> plt.pcolormesh(t, freq, np.abs(cwtm), cmap='viridis', shading='gouraud')
    >>> plt.show()

    """
    # Sample positions centred on zero, expressed in units of the width `s`.
    centred = np.arange(0, M) - (M - 1.0) / 2
    scaled = centred / s

    shape = np.exp(1j * w * scaled) * np.exp(-0.5 * scaled**2) * np.pi**(-0.25)
    # sqrt(1/s) normalises the wavelet according to its width.
    return np.sqrt(1/s) * shape
def cwt(data, wavelet, widths, dtype=None, **kwargs):
    """
    Continuous wavelet transform.

    Performs a continuous wavelet transform on `data`,
    using the `wavelet` function. A CWT performs a convolution
    with `data` using the `wavelet` function, which is characterized
    by a width parameter and length parameter. The `wavelet` function
    is allowed to be complex.

    Parameters
    ----------
    data : (N,) ndarray
        data on which to perform the transform.
    wavelet : function
        Wavelet function, which should take 2 arguments.
        The first argument is the number of points that the returned vector
        will have (len(wavelet(length,width)) == length).
        The second is a width parameter, defining the size of the wavelet
        (e.g. standard deviation of a gaussian). See `ricker`, which
        satisfies these requirements.
    widths : (M,) sequence
        Widths to use for transform.
    dtype : data-type, optional
        The desired data type of output. Defaults to ``float64`` if the
        output of `wavelet` is real and ``complex128`` if it is complex.

        .. versionadded:: 1.4.0

    kwargs
        Keyword arguments passed to wavelet function. When `wavelet` is
        `ricker`, a ``window_size`` entry is removed before the remaining
        keywords are forwarded.

        .. versionadded:: 1.4.0

    Returns
    -------
    cwt: (M, N) ndarray
        Will have shape of (len(widths), len(data)).

    Notes
    -----
    .. versionadded:: 1.4.0

    For non-symmetric, complex-valued wavelets, the input signal is convolved
    with the time-reversed complex-conjugate of the wavelet data [1].

    ::

        length = min(10 * width[ii], len(data))
        cwt[ii,:] = signal.convolve(data, np.conj(wavelet(length, width[ii],
                                        **kwargs))[::-1], mode='same')

    References
    ----------
    .. [1] S. Mallat, "A Wavelet Tour of Signal Processing (3rd Edition)",
        Academic Press, 2009.

    Examples
    --------
    >>> from scipy import signal
    >>> import matplotlib.pyplot as plt
    >>> t = np.linspace(-1, 1, 200, endpoint=False)
    >>> sig  = np.cos(2 * np.pi * 7 * t) + signal.gausspulse(t - 0.4, fc=2)
    >>> widths = np.arange(1, 31)
    >>> cwtmatr = signal.cwt(sig, signal.ricker, widths)
    >>> plt.imshow(cwtmatr, extent=[-1, 1, 1, 31], cmap='PRGn', aspect='auto',
    ...            vmax=abs(cwtmatr).max(), vmin=-abs(cwtmatr).max())
    >>> plt.show()

    """
    # The ``window_size`` pop and the integer truncation in the loop below
    # may be removed eventually; they are shims for 32-bit arch +
    # NumPy <= 1.14.5 to address gh-11095.
    if wavelet == ricker:
        window_size = kwargs.pop('window_size', None)
        truncate_length = window_size is None
    else:
        truncate_length = False

    # Determine the output type by probing the wavelet with a single point.
    if dtype is None:
        probe = np.asarray(wavelet(1, widths[0], **kwargs))
        dtype = np.complex128 if probe.dtype.char in 'FDG' else np.float64

    n_samples = len(data)
    output = np.empty((len(widths), n_samples), dtype=dtype)
    for row, width in enumerate(widths):
        length = np.min([10 * width, n_samples])
        # Only a non-integral length is converted to int.
        if truncate_length and np.ceil(length) != length:
            length = int(length)
        # Time-reversed complex conjugate of the sampled wavelet.
        kernel = np.conj(wavelet(length, width, **kwargs)[::-1])
        output[row] = convolve(data, kernel, mode='same')
    return output