SHARE
TWEET

Untitled

Posted by a guest on May 24th, 2019 · 70 views · Expires: Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. from vispy import gloo, app, visuals
  2.  
  3. import numpy as np
  4. import math
  5. from seaborn import color_palette
  6. from pylsl import StreamInlet, resolve_byprop
  7. from scipy.signal import lfilter, lfilter_zi
  8. from mne.filter import create_filter
  9. from .constants import LSL_SCAN_TIMEOUT, LSL_EEG_CHUNK
  10.  
  11.  
  12. def view():
  13.     print("Looking for an EEG stream...")
  14.     streams = resolve_byprop('type', 'EEG', timeout=LSL_SCAN_TIMEOUT)
  15.  
  16.     if len(streams) == 0:
  17.         raise(RuntimeError("Can't find EEG stream."))
  18.     print("Start acquiring data.")
  19.  
  20.     inlet = StreamInlet(streams[0], max_chunklen=LSL_EEG_CHUNK)
  21.     Canvas(inlet)
  22.     app.run()
  23.  
  24.  
  25. class Canvas(app.Canvas):
  26.     def __init__(self, lsl_inlet, scale=500, filt=True):
  27.         app.Canvas.__init__(self, title='EEG - Use your wheel to zoom!',
  28.                             keys='interactive')
  29.  
  30.         self.inlet = lsl_inlet
  31.         info = self.inlet.info()
  32.         description = info.desc()
  33.  
  34.         window = 10
  35.         self.sfreq = info.nominal_srate()
  36.         n_samples = int(self.sfreq * window)
  37.         self.n_chans = info.channel_count()
  38.  
  39.         ch = description.child('channels').first_child()
  40.         ch_names = [ch.child_value('label')]
  41.  
  42.         for i in range(self.n_chans):
  43.             ch = ch.next_sibling()
  44.             ch_names.append(ch.child_value('label'))
  45.  
  46.         # Number of cols and rows in the table.
  47.         n_rows = self.n_chans
  48.         n_cols = 1
  49.  
  50.         # Number of signals.
  51.         m = n_rows * n_cols
  52.  
  53.         # Number of samples per signal.
  54.         n = n_samples
  55.  
  56.         # Various signal amplitudes.
  57.         amplitudes = np.zeros((m, n)).astype(np.float32)
  58.         # gamma = np.ones((m, n)).astype(np.float32)
  59.         # Generate the signals as a (m, n) array.
  60.         y = amplitudes
  61.  
  62.         color = color_palette("RdBu_r", n_rows)
  63.  
  64.         color = np.repeat(color, n, axis=0).astype(np.float32)
  65.         # Signal 2D index of each vertex (row and col) and x-index (sample index
  66.         # within each signal).
  67.         index = np.c_[np.repeat(np.repeat(np.arange(n_cols), n_rows), n),
  68.                       np.repeat(np.tile(np.arange(n_rows), n_cols), n),
  69.                       np.tile(np.arange(n), m)].astype(np.float32)
  70.  
  71.         self.program = gloo.Program(VERT_SHADER, FRAG_SHADER)
  72.         self.program['a_position'] = y.reshape(-1, 1)
  73.         self.program['a_color'] = color
  74.         self.program['a_index'] = index
  75.         self.program['u_scale'] = (1., 1.)
  76.         self.program['u_size'] = (n_rows, n_cols)
  77.         self.program['u_n'] = n
  78.  
  79.         # text
  80.         self.font_size = 48.
  81.         self.names = []
  82.         self.quality = []
  83.         for ii in range(self.n_chans):
  84.             text = visuals.TextVisual(ch_names[ii], bold=True, color='white')
  85.             self.names.append(text)
  86.             text = visuals.TextVisual('', bold=True, color='white')
  87.             self.quality.append(text)
  88.  
  89.         self.quality_colors = color_palette("RdYlGn", 11)[::-1]
  90.  
  91.         self.scale = scale
  92.         self.n_samples = n_samples
  93.         self.filt = filt
  94.         self.af = [1.0]
  95.  
  96.         self.data_f = np.zeros((n_samples, self.n_chans))
  97.         self.data = np.zeros((n_samples, self.n_chans))
  98.  
  99.         self.bf = create_filter(self.data_f.T, self.sfreq, 3, 40.,
  100.                                 method='fir')
  101.  
  102.         zi = lfilter_zi(self.bf, self.af)
  103.         self.filt_state = np.tile(zi, (self.n_chans, 1)).transpose()
  104.  
  105.         self._timer = app.Timer('auto', connect=self.on_timer, start=True)
  106.         gloo.set_viewport(0, 0, *self.physical_size)
  107.         gloo.set_state(clear_color='black', blend=True,
  108.                        blend_func=('src_alpha', 'one_minus_src_alpha'))
  109.  
  110.         self.show()
  111.  
  112.     def on_key_press(self, event):
  113.  
  114.         # toggle filtering
  115.         if event.key.name == 'D':
  116.             self.filt = not self.filt
  117.  
  118.         # increase time scale
  119.         if event.key.name in ['+', '-']:
  120.             if event.key.name == '+':
  121.                 dx = -0.05
  122.             else:
  123.                 dx = 0.05
  124.             scale_x, scale_y = self.program['u_scale']
  125.             scale_x_new, scale_y_new = (scale_x * math.exp(1.0 * dx),
  126.                                         scale_y * math.exp(0.0 * dx))
  127.             self.program['u_scale'] = (
  128.                 max(1, scale_x_new), max(1, scale_y_new))
  129.             self.update()
  130.  
  131.     def on_mouse_wheel(self, event):
  132.         dx = np.sign(event.delta[1]) * .05
  133.         scale_x, scale_y = self.program['u_scale']
  134.         scale_x_new, scale_y_new = (scale_x * math.exp(0.0 * dx),
  135.                                     scale_y * math.exp(2.0 * dx))
  136.         self.program['u_scale'] = (max(1, scale_x_new), max(0.01, scale_y_new))
  137.         self.update()
  138.  
  139.     def on_timer(self, event):
  140.         """Add some data at the end of each signal (real-time signals)."""
  141.  
  142.         samples, timestamps = self.inlet.pull_chunk(timeout=0.0,
  143.                                                     max_samples=100)
  144.         if timestamps:
  145.             samples = np.array(samples)[:, ::-1]
  146.  
  147.             self.data = np.vstack([self.data, samples])
  148.             self.data = self.data[-self.n_samples:]
  149.             filt_samples, self.filt_state = lfilter(self.bf, self.af, samples,
  150.                                                     axis=0, zi=self.filt_state)
  151.             self.data_f = np.vstack([self.data_f, filt_samples])
  152.             self.data_f = self.data_f[-self.n_samples:]
  153.  
  154.             if self.filt:
  155.                 plot_data = self.data_f / self.scale
  156.             elif not self.filt:
  157.                 plot_data = (self.data - self.data.mean(axis=0)) / self.scale
  158.  
  159.             sd = np.std(plot_data[-int(self.sfreq):],
  160.                         axis=0)[::-1] * self.scale
  161.             co = np.int32(np.tanh((sd - 30) / 15) * 5 + 5)
  162.             for ii in range(self.n_chans):
  163.                 self.quality[ii].text = '%.2f' % (sd[ii])
  164.                 self.quality[ii].color = self.quality_colors[co[ii]]
  165.                 self.quality[ii].font_size = 12 + co[ii]
  166.  
  167.                 self.names[ii].font_size = 12 + co[ii]
  168.                 self.names[ii].color = self.quality_colors[co[ii]]
  169.  
  170.             self.program['a_position'].set_data(
  171.                 plot_data.T.ravel().astype(np.float32))
  172.             self.update()
  173.  
  174.     def on_resize(self, event):
  175.         # Set canvas viewport and reconfigure visual transforms to match.
  176.         vp = (0, 0, self.physical_size[0], self.physical_size[1])
  177.         self.context.set_viewport(*vp)
  178.  
  179.         for ii, t in enumerate(self.names):
  180.             t.transforms.configure(canvas=self, viewport=vp)
  181.             t.pos = (self.size[0] * 0.025,
  182.                      ((ii + 0.5) / self.n_chans) * self.size[1])
  183.  
  184.         for ii, t in enumerate(self.quality):
  185.             t.transforms.configure(canvas=self, viewport=vp)
  186.             t.pos = (self.size[0] * 0.975,
  187.                      ((ii + 0.5) / self.n_chans) * self.size[1])
  188.  
  189.     def on_draw(self, event):
  190.         gloo.clear()
  191.         gloo.set_viewport(0, 0, *self.physical_size)
  192.         self.program.draw('line_strip')
  193.         [t.draw() for t in self.names + self.quality]
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top