Advertisement
Guest User

Untitled

a guest
May 24th, 2019
149
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.07 KB | None | 0 0
  1. from vispy import gloo, app, visuals
  2.  
  3. import numpy as np
  4. import math
  5. from seaborn import color_palette
  6. from pylsl import StreamInlet, resolve_byprop
  7. from scipy.signal import lfilter, lfilter_zi
  8. from mne.filter import create_filter
  9. from .constants import LSL_SCAN_TIMEOUT, LSL_EEG_CHUNK
  10.  
  11.  
  12. def view():
  13.     print("Looking for an EEG stream...")
  14.     streams = resolve_byprop('type', 'EEG', timeout=LSL_SCAN_TIMEOUT)
  15.  
  16.     if len(streams) == 0:
  17.         raise(RuntimeError("Can't find EEG stream."))
  18.     print("Start acquiring data.")
  19.  
  20.     inlet = StreamInlet(streams[0], max_chunklen=LSL_EEG_CHUNK)
  21.     Canvas(inlet)
  22.     app.run()
  23.  
  24.  
  25. class Canvas(app.Canvas):
  26.     def __init__(self, lsl_inlet, scale=500, filt=True):
  27.         app.Canvas.__init__(self, title='EEG - Use your wheel to zoom!',
  28.                             keys='interactive')
  29.  
  30.         self.inlet = lsl_inlet
  31.         info = self.inlet.info()
  32.         description = info.desc()
  33.  
  34.         window = 10
  35.         self.sfreq = info.nominal_srate()
  36.         n_samples = int(self.sfreq * window)
  37.         self.n_chans = info.channel_count()
  38.  
  39.         ch = description.child('channels').first_child()
  40.         ch_names = [ch.child_value('label')]
  41.  
  42.         for i in range(self.n_chans):
  43.             ch = ch.next_sibling()
  44.             ch_names.append(ch.child_value('label'))
  45.  
  46.         # Number of cols and rows in the table.
  47.         n_rows = self.n_chans
  48.         n_cols = 1
  49.  
  50.         # Number of signals.
  51.         m = n_rows * n_cols
  52.  
  53.         # Number of samples per signal.
  54.         n = n_samples
  55.  
  56.         # Various signal amplitudes.
  57.         amplitudes = np.zeros((m, n)).astype(np.float32)
  58.         # gamma = np.ones((m, n)).astype(np.float32)
  59.         # Generate the signals as a (m, n) array.
  60.         y = amplitudes
  61.  
  62.         color = color_palette("RdBu_r", n_rows)
  63.  
  64.         color = np.repeat(color, n, axis=0).astype(np.float32)
  65.         # Signal 2D index of each vertex (row and col) and x-index (sample index
  66.         # within each signal).
  67.         index = np.c_[np.repeat(np.repeat(np.arange(n_cols), n_rows), n),
  68.                       np.repeat(np.tile(np.arange(n_rows), n_cols), n),
  69.                       np.tile(np.arange(n), m)].astype(np.float32)
  70.  
  71.         self.program = gloo.Program(VERT_SHADER, FRAG_SHADER)
  72.         self.program['a_position'] = y.reshape(-1, 1)
  73.         self.program['a_color'] = color
  74.         self.program['a_index'] = index
  75.         self.program['u_scale'] = (1., 1.)
  76.         self.program['u_size'] = (n_rows, n_cols)
  77.         self.program['u_n'] = n
  78.  
  79.         # text
  80.         self.font_size = 48.
  81.         self.names = []
  82.         self.quality = []
  83.         for ii in range(self.n_chans):
  84.             text = visuals.TextVisual(ch_names[ii], bold=True, color='white')
  85.             self.names.append(text)
  86.             text = visuals.TextVisual('', bold=True, color='white')
  87.             self.quality.append(text)
  88.  
  89.         self.quality_colors = color_palette("RdYlGn", 11)[::-1]
  90.  
  91.         self.scale = scale
  92.         self.n_samples = n_samples
  93.         self.filt = filt
  94.         self.af = [1.0]
  95.  
  96.         self.data_f = np.zeros((n_samples, self.n_chans))
  97.         self.data = np.zeros((n_samples, self.n_chans))
  98.  
  99.         self.bf = create_filter(self.data_f.T, self.sfreq, 3, 40.,
  100.                                 method='fir')
  101.  
  102.         zi = lfilter_zi(self.bf, self.af)
  103.         self.filt_state = np.tile(zi, (self.n_chans, 1)).transpose()
  104.  
  105.         self._timer = app.Timer('auto', connect=self.on_timer, start=True)
  106.         gloo.set_viewport(0, 0, *self.physical_size)
  107.         gloo.set_state(clear_color='black', blend=True,
  108.                        blend_func=('src_alpha', 'one_minus_src_alpha'))
  109.  
  110.         self.show()
  111.  
  112.     def on_key_press(self, event):
  113.  
  114.         # toggle filtering
  115.         if event.key.name == 'D':
  116.             self.filt = not self.filt
  117.  
  118.         # increase time scale
  119.         if event.key.name in ['+', '-']:
  120.             if event.key.name == '+':
  121.                 dx = -0.05
  122.             else:
  123.                 dx = 0.05
  124.             scale_x, scale_y = self.program['u_scale']
  125.             scale_x_new, scale_y_new = (scale_x * math.exp(1.0 * dx),
  126.                                         scale_y * math.exp(0.0 * dx))
  127.             self.program['u_scale'] = (
  128.                 max(1, scale_x_new), max(1, scale_y_new))
  129.             self.update()
  130.  
  131.     def on_mouse_wheel(self, event):
  132.         dx = np.sign(event.delta[1]) * .05
  133.         scale_x, scale_y = self.program['u_scale']
  134.         scale_x_new, scale_y_new = (scale_x * math.exp(0.0 * dx),
  135.                                     scale_y * math.exp(2.0 * dx))
  136.         self.program['u_scale'] = (max(1, scale_x_new), max(0.01, scale_y_new))
  137.         self.update()
  138.  
  139.     def on_timer(self, event):
  140.         """Add some data at the end of each signal (real-time signals)."""
  141.  
  142.         samples, timestamps = self.inlet.pull_chunk(timeout=0.0,
  143.                                                     max_samples=100)
  144.         if timestamps:
  145.             samples = np.array(samples)[:, ::-1]
  146.  
  147.             self.data = np.vstack([self.data, samples])
  148.             self.data = self.data[-self.n_samples:]
  149.             filt_samples, self.filt_state = lfilter(self.bf, self.af, samples,
  150.                                                     axis=0, zi=self.filt_state)
  151.             self.data_f = np.vstack([self.data_f, filt_samples])
  152.             self.data_f = self.data_f[-self.n_samples:]
  153.  
  154.             if self.filt:
  155.                 plot_data = self.data_f / self.scale
  156.             elif not self.filt:
  157.                 plot_data = (self.data - self.data.mean(axis=0)) / self.scale
  158.  
  159.             sd = np.std(plot_data[-int(self.sfreq):],
  160.                         axis=0)[::-1] * self.scale
  161.             co = np.int32(np.tanh((sd - 30) / 15) * 5 + 5)
  162.             for ii in range(self.n_chans):
  163.                 self.quality[ii].text = '%.2f' % (sd[ii])
  164.                 self.quality[ii].color = self.quality_colors[co[ii]]
  165.                 self.quality[ii].font_size = 12 + co[ii]
  166.  
  167.                 self.names[ii].font_size = 12 + co[ii]
  168.                 self.names[ii].color = self.quality_colors[co[ii]]
  169.  
  170.             self.program['a_position'].set_data(
  171.                 plot_data.T.ravel().astype(np.float32))
  172.             self.update()
  173.  
  174.     def on_resize(self, event):
  175.         # Set canvas viewport and reconfigure visual transforms to match.
  176.         vp = (0, 0, self.physical_size[0], self.physical_size[1])
  177.         self.context.set_viewport(*vp)
  178.  
  179.         for ii, t in enumerate(self.names):
  180.             t.transforms.configure(canvas=self, viewport=vp)
  181.             t.pos = (self.size[0] * 0.025,
  182.                      ((ii + 0.5) / self.n_chans) * self.size[1])
  183.  
  184.         for ii, t in enumerate(self.quality):
  185.             t.transforms.configure(canvas=self, viewport=vp)
  186.             t.pos = (self.size[0] * 0.975,
  187.                      ((ii + 0.5) / self.n_chans) * self.size[1])
  188.  
  189.     def on_draw(self, event):
  190.         gloo.clear()
  191.         gloo.set_viewport(0, 0, *self.physical_size)
  192.         self.program.draw('line_strip')
  193.         [t.draw() for t in self.names + self.quality]
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement