Untitled

a guest
Dec 27th, 2020
149
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
1. # -*- coding: utf-8 -*-
2. import numpy as np
3. from numpy.fft import fft, ifft, ifftshift
4. from ssqueezepy.utils import padsignal, process_scales
5. from ssqueezepy.algos import find_closest, indexed_sum
6. from ssqueezepy.visuals import imshow, plot, scat
7. from ssqueezepy import Wavelet
8.
9. #%%# Helpers #################################################################
10. def _t(min, max, N):
11.     return np.linspace(min, max, N, endpoint=1)
12.
def cos_f(freqs, N=128, phi=0):
    """Concatenate one unit-length cosine segment per entry of `freqs`.

    Segment `i` is sampled at `N` points over t in [i, i + 1] (both ends
    included, via `_t`) and equals cos(2*pi*f*(t + phi)). `phi` therefore
    shifts time, not phase in radians. Adjacent segments share their boundary
    time (t = i + 1 ends segment i and starts segment i + 1).
    Returns an array of length N * len(freqs).
    """
    segments = []
    for idx, freq in enumerate(freqs):
        t = _t(idx, idx + 1, N) + phi
        segments.append(np.cos(2 * np.pi * freq * t))
    return np.concatenate(segments)
16.
#%%# Define signal & wavelet #################################################
# Test signal: a single 8 Hz cosine segment, sampled at `fs` points over
# t in [0, 1] (see `cos_f`).
fs = 129
x = cos_f([8], N=fs)
plot(x, title="Pure sine | N=129, f=8")
scat(x, show=1)

# Morlet wavelet configured with mu=4
wavelet_spec = ('morlet', {'mu': 4})
wavelet = Wavelet(wavelet_spec)
24.
#%%# CWT #####################################################################
nv = 32       # number of voices per octave, as passed to `process_scales`
dt = 1 / fs   # sampling period [seconds]

n = len(x)  # store original length
# Pad the signal before the FFT. `nup` is the padded length, and `n1` is the
# left pad length, used at the end to trim the transforms back to the
# original `n` samples.
x, nup, n1, _ = padsignal(x, padtype='reflect', get_params=True)

x -= x.mean()
xh = fft(x)

scales = process_scales(scales='log:maximal', len_x=n, wavelet=wavelet, nv=nv)
# Alternating-sign sequence (+1, -1, +1, ...) of the padded length. It is
# applied to the spectra before the inverse FFT (paired with `ifftshift`).
pn = (-1)**np.arange(nup)

# Temporarily evaluate the wavelet on the padded length so `Psih` lines up
# with `xh`. The original `N` is restored once the transforms are computed.
N_orig = wavelet.N
wavelet.N = nup

#%%# cwt ####
Psih = (wavelet(scale=scales, nohalf=False)).astype('complex128')
# Time derivative taken in the frequency domain: multiply by 1j * xi, with
# 1/dt converting to per-second units. (Assumes `wavelet.xi` is the wavelet's
# frequency grid; TODO confirm.)
dPsih = (1j * wavelet.xi / dt) * Psih

Wx  = ifftshift(ifft(pn * Psih  * xh, axis=-1), axes=-1)
dWx = ifftshift(ifft(pn * dPsih * xh, axis=-1), axes=-1)
#%%#
wavelet.N = N_orig
# shorten to pre-padded size
Wx  = Wx[:,  n1:n1 + n]
dWx = dWx[:, n1:n1 + n]
50.
#%%# Phase transform #########################################################
# Instantaneous frequency of each CWT coefficient: Im(dWx / Wx) / (2*pi)
w = (dWx / Wx).imag / (2 * np.pi)

# Zero out `w` wherever |Wx| is below its mean. Those tiny coefficients carry
# large `w` values; removing them makes no noticeable difference to `Tx` but
# makes the plots far easier to read.
Wx_abs = np.abs(Wx)
w[Wx_abs < Wx_abs.mean()] = 0
57.
#%%# Reassignment frequencies (mapkind='maximal') ############################
na = Wx.shape[0]  # number of scales (rows)
N = Wx.shape[1]   # number of time samples (columns), after trimming
dT = dt * N       # duration spanned by the trimmed signal [seconds]

# Physical frequency bounds, mapping f[[cycles/samples]] -> f[[cycles/second]]
fm = 1 / dT        # minimum measurable (fundamental) frequency of data
fM = 1 / (2 * dt)  # maximum measurable (Nyquist) frequency of data

# `na` log-spaced frequencies running from fm up to fM (both included)
log_frac = np.arange(na) / (na - 1)
ssq_freqs = fm * (fM / fm) ** log_frac
69.
#%%# Reassignment indices
# Replace each instantaneous frequency in `w` by the index of the nearest
# `ssq_freqs` entry, compared on a log2 scale. E.g. with `ssq_freqs` spanning
# 2 to 48 and w[5, 2] == 5, k[5, 2] is the index where ssq_freqs == 5 (or the
# closest value), so `k` ranges over 0 .. len(ssq_freqs) - 1.
# NOTE(review): entries of `w` zeroed above become -inf under log2, and any
# negative `w` becomes NaN. How `find_closest` places those isn't visible here.
log2_w = np.log2(w)
log2_freqs = np.log2(ssq_freqs)
k = find_closest(log2_w, log2_freqs)

#%%# Synchrosqueeze #########################################################
# Accumulate the CWT coefficients, weighted by ln(2) / nv, into the frequency
# rows given by `k`. Then flip vertically so the image lines up with `Wx`.
Tx = np.flipud(indexed_sum(Wx * np.log(2) / nv, k))
80.
#%%# Visualize ##############################################################
kw = dict(abs=1, cmap='jet', show=1, aspect='auto')
imshow(Wx, title="abs(CWT)", ylabel="scales", yticks=scales, **kw)
# `Tx` was flipped vertically after synchrosqueezing, so its row i holds
# ssq_freqs[-1 - i]. Reverse the (ascending) frequencies so the tick labels
# match the rows.
imshow(Tx, title="abs(SSQ_CWT)", ylabel="frequencies", yticks=ssq_freqs[::-1],
       **kw)
#%%# Zoom ####
a, b = 50, 158  # row range to zoom into
c, d = 0, None  # column range; None runs through the last column
idxs = np.arange(len(scales), dtype='int64')  # row indices, used as tick labels
imshow(Wx[a:b, c:d], title="abs(CWT), zoomed", yticks=scales[a:b], **kw)
imshow(Tx[a:b, c:d], title="abs(SSQ_CWT), zoomed", yticks=idxs[a:b], **kw)
91.
#%%
w_slice = w[81:120, 0]
plot(w_slice, xticks=scales[81:120], show=1,
     title="w[81:120, 0] | Phase transform across zoomed scales")
#%%# Repeat for `w` ####
w_title = "Phase transform | (min, max) = (%.3f, %.3f)" % (w.min(), w.max())
imshow(w, title=w_title, ylabel="scales", yticks=scales, **kw)
imshow(w[a:b, c:d], title="Phase transform, zoomed", yticks=scales[a:b], **kw)
#%%# zoom
# Color limits come from the `w` values above 1e-3, which excludes the
# entries zeroed in the phase-transform step.
w_pos = w[w > 1e-3]
wmn, wmx = w_pos.min(), w_pos.max()
imshow(w[a:b, c:d], title="Phase transform, magnitude-zoomed",
       yticks=scales[a:b], norm=(wmn, wmx), **kw)
#%%# Repeat for `k`
k_title = ("Phase transform, index-equivalent, (min, max) = "
           "(%d, %d)" % (k.min(), k.max()))
imshow(k, title=k_title, ylabel="rows (scale indices)", yticks=idxs, **kw)
imshow(k[a:b, c:d], title="Phase transform, index-equivalent, zoomed",
       yticks=idxs[a:b], **kw)
110.