Advertisement
Guest User

Untitled

a guest
Mar 21st, 2019
62
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.29 KB | None | 0 0
  1. from functools import partial, reduce
  2.  
  3. import numpy as np
  4.  
  5. import pywt
  6. from matplotlib import pyplot as plt
  7.  
  8.  
  9. def _rescale_wavelet_filterbank(wavelet, sf):
  10. from pywt._utils import _as_wavelet
  11. wavelet = _as_wavelet(wavelet) # convert string to pywt.Wavelet if needed
  12. wav = pywt.Wavelet(wavelet.name + 'r',
  13. [np.asarray(f) * sf for f in wavelet.filter_bank])
  14.  
  15. # copy attributes from the original wavelet
  16. wav.orthogonal = wavelet.orthogonal
  17. wav.biorthogonal = wavelet.biorthogonal
  18. return wav
  19.  
  20.  
  21. def mra(data, wavelet='sym4', level=None, transform='swt'):
  22. """Forward 1D multiresolution analysis.
  23.  
  24. This is also known as an additive decomposition.
  25.  
  26. Parameters
  27. ----------
  28. data: array_like
  29. Input data
  30. wavelet : Wavelet object or name string
  31. Wavelet to use
  32. level : int, optional
  33. Decomposition level (must be >= 0). If level is None (default) then it
  34. will be calculated using the `dwt_max_level` function.
  35. transform : {'dwt', 'swt'}
  36. Whether to use the DWT or SWT for the transforms.
  37.  
  38. Returns
  39. -------
  40. [cAn, {details_level_n}, ... {details_level_1}] : list
  41. """
  42.  
  43. if transform == 'swt':
  44. # use different normalization
  45. kwargs = dict(wavelet=wavelet)
  46.  
  47. def forward(data, wavelet=wavelet, level=level):
  48. """Forward SWT without redundant coefficients."""
  49. coeffs = pywt.swt(data, wavelet=wavelet, level=level)
  50. # discarded uneeded approximation coeffs
  51. coeffs = [coeffs[0][0], coeffs[0][1]] + [c[1] for c in coeffs[1:]]
  52. return coeffs
  53.  
  54. def inverse(coeffs, wavelet=wavelet):
  55. """Inverse SWT without redundant coefficients."""
  56. coeffs = [(coeffs[0], coeffs[1])] + [(np.zeros_like(c), c)
  57. for c in coeffs[2:]]
  58. rec = pywt.iswt(coeffs, wavelet=wavelet)
  59. return rec
  60.  
  61. elif transform == 'dwt':
  62. kwargs = dict(wavelet=wavelet, mode='periodization')
  63. forward = partial(pywt.wavedec, level=level, **kwargs)
  64. inverse = partial(pywt.waverec, **kwargs)
  65.  
  66. wav_coeffs = forward(data)
  67.  
  68. # compute MRA projections
  69. mra_coeffs = []
  70. nc = len(wav_coeffs)
  71. tmp = [np.zeros_like(c) for c in wav_coeffs]
  72. for j in range(nc):
  73. # tmp has arrays of zeros except for the jth entry
  74. tmp[j] = wav_coeffs[j]
  75.  
  76. # reconstruct
  77. rec = inverse(tmp)
  78. if rec.shape != data.shape:
  79. # trim any excess coefficients
  80. rec = rec[tuple([slice(sz) for sz in data.shape])]
  81. mra_coeffs.append(rec)
  82.  
  83. # restore zeros
  84. tmp[j] = np.zeros_like(tmp[j])
  85. return mra_coeffs
  86.  
  87. def imra(mra_coeffs):
  88. """Inverse 1D multiresolution analysis via summation.
  89.  
  90. Parameters
  91. ----------
  92. mra_coeffs : list of ndarray
  93. Multiresolution analysis coefficients as returned by `mra`.
  94.  
  95. Returns
  96. -------
  97. rec : ndarray
  98. The reconstructed signal.
  99. """
  100. return reduce(lambda x, y: x + y, mra_coeffs)
  101.  
  102.  
  103. # call the forward transform on demo ECG data
  104. mra_coeffs = mra(pywt.data.ecg(), wavelet='sym4')
  105.  
  106. # plot all MRA coefficients
  107. plt.figure()
  108. plt.plot(np.stack(mra_coeffs, axis=-1))
  109.  
  110. # reconstruct from the MRA coefficients
  111. r = imra(mra_coeffs)
  112. plt.figure()
  113. plt.plot(r)
  114.  
  115. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement