Advertisement
Not a member of Pastebin yet? Sign up, it unlocks many cool features!
- from functools import partial, reduce
- import numpy as np
- import pywt
- from matplotlib import pyplot as plt
def _rescale_wavelet_filterbank(wavelet, sf):
    """Build a wavelet whose filter bank is ``wavelet``'s scaled by ``sf``.

    Parameters
    ----------
    wavelet : Wavelet object or name string
        Source wavelet; a name string is converted to a ``pywt.Wavelet``.
    sf : scalar
        Factor every filter in the filter bank is multiplied by.

    Returns
    -------
    pywt.Wavelet
        New wavelet named ``<original name> + 'r'`` carrying the scaled
        filters and the original's ``orthogonal`` / ``biorthogonal`` flags.
    """
    from pywt._utils import _as_wavelet

    # Accept either a wavelet name or an existing Wavelet instance.
    base = _as_wavelet(wavelet)

    scaled_bank = []
    for taps in base.filter_bank:
        scaled_bank.append(np.asarray(taps) * sf)

    rescaled = pywt.Wavelet(base.name + 'r', scaled_bank)

    # A wavelet built from a raw filter bank does not inherit these flags,
    # so carry them over from the source wavelet explicitly.
    rescaled.orthogonal = base.orthogonal
    rescaled.biorthogonal = base.biorthogonal
    return rescaled
def mra(data, wavelet='sym4', level=None, transform='swt'):
    """Forward 1D multiresolution analysis.

    This is also known as an additive decomposition: each returned
    component has the same shape as ``data`` and the components sum back
    to the input (see ``imra``).

    Parameters
    ----------
    data : array_like
        Input data. It is converted with ``np.asarray`` before use.
    wavelet : Wavelet object or name string
        Wavelet to use.
    level : int, optional
        Decomposition level. If None (default) it is forwarded unchanged
        to the underlying pywt transform, which then chooses the level
        itself.
    transform : {'dwt', 'swt'}
        Whether to use the DWT (``pywt.wavedec`` / ``pywt.waverec`` in
        'periodization' mode) or the SWT (``pywt.swt`` / ``pywt.iswt``)
        for the transforms.

    Returns
    -------
    [cAn, {details_level_n}, ... {details_level_1}] : list of ndarray
        The projection onto the coarsest approximation followed by the
        projections onto each detail level, coarsest first.

    Raises
    ------
    ValueError
        If ``transform`` is neither 'swt' nor 'dwt'.
    """
    # Validate first: otherwise an unknown transform would only surface
    # later as an unbound-local error on ``forward``.
    if transform not in ('swt', 'dwt'):
        raise ValueError(
            "transform must be 'swt' or 'dwt', got %r" % (transform,))

    # The shape comparison below needs an ndarray, not a list or tuple.
    data = np.asarray(data)

    if transform == 'swt':
        def forward(x):
            """Forward SWT without redundant coefficients."""
            coeffs = pywt.swt(x, wavelet=wavelet, level=level)
            # Keep only the coarsest approximation; the finer-level
            # approximation coefficients are not needed.
            return [coeffs[0][0], coeffs[0][1]] + [c[1] for c in coeffs[1:]]

        def inverse(coeffs):
            """Inverse SWT without redundant coefficients."""
            # Rebuild the (approximation, detail) pairs pywt.iswt expects,
            # filling the discarded approximations with zeros.
            pairs = [(coeffs[0], coeffs[1])]
            pairs += [(np.zeros_like(c), c) for c in coeffs[2:]]
            return pywt.iswt(pairs, wavelet=wavelet)
    else:
        kwargs = dict(wavelet=wavelet, mode='periodization')
        forward = partial(pywt.wavedec, level=level, **kwargs)
        inverse = partial(pywt.waverec, **kwargs)

    wav_coeffs = forward(data)

    # Zero placeholders are allocated once and reused for every projection.
    zeros = [np.zeros_like(c) for c in wav_coeffs]
    trim = tuple(slice(sz) for sz in data.shape)

    # compute MRA projections
    mra_coeffs = []
    for j, band in enumerate(wav_coeffs):
        # all-zero coefficient list except for the jth entry
        isolated = list(zeros)
        isolated[j] = band
        rec = inverse(isolated)
        if rec.shape != data.shape:
            # trim any excess samples produced by the inverse transform
            rec = rec[trim]
        mra_coeffs.append(rec)
    return mra_coeffs
def imra(mra_coeffs):
    """Inverse 1D multiresolution analysis via summation.

    Parameters
    ----------
    mra_coeffs : list of ndarray
        Multiresolution analysis coefficients as returned by `mra`.

    Returns
    -------
    rec : ndarray
        The reconstructed signal, i.e. the sum of all components.

    Raises
    ------
    TypeError
        If ``mra_coeffs`` is empty.
    """
    components = iter(mra_coeffs)
    try:
        rec = next(components)
    except StopIteration:
        raise TypeError(
            "reduce() of empty iterable with no initial value") from None
    for component in components:
        # Plain ``+`` (not ``+=``) so the first component is never
        # modified in place.
        rec = rec + component
    return rec
# Demo: decompose the ECG sample shipped with PyWavelets into MRA components.
_ecg_signal = pywt.data.ecg()
mra_coeffs = mra(_ecg_signal, wavelet='sym4')

# First figure: all MRA components, stacked along the last axis so that
# each component is drawn as its own curve.
_components = np.stack(mra_coeffs, axis=-1)
plt.figure()
plt.plot(_components)

# Second figure: the signal rebuilt by summing the MRA components.
r = imra(mra_coeffs)
plt.figure()
plt.plot(r)

plt.show()
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement