Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from vispy import gloo, app, visuals
- import numpy as np
- import math
- from seaborn import color_palette
- from pylsl import StreamInlet, resolve_byprop
- from scipy.signal import lfilter, lfilter_zi
- from mne.filter import create_filter
- from .constants import LSL_SCAN_TIMEOUT, LSL_EEG_CHUNK
def view():
    """Open a live viewer on the first LSL stream whose type is 'EEG'.

    Resolves streams with ``resolve_byprop`` (waiting up to
    ``LSL_SCAN_TIMEOUT``), opens an inlet on the first match, builds a
    ``Canvas`` on it and then hands control to vispy via ``app.run()``.

    Raises
    ------
    RuntimeError
        If the scan finds no EEG stream.
    """
    print("Looking for an EEG stream...")
    found = resolve_byprop('type', 'EEG', timeout=LSL_SCAN_TIMEOUT)
    if not found:
        raise RuntimeError("Can't find EEG stream.")

    print("Start acquiring data.")
    eeg_inlet = StreamInlet(found[0], max_chunklen=LSL_EEG_CHUNK)
    Canvas(eeg_inlet)
    app.run()
class Canvas(app.Canvas):
    """Real-time scrolling plot of every channel of an LSL EEG stream.

    Channels are drawn as stacked traces holding the last 10 seconds of
    data. Each trace has its channel name on the left and, on the right, the
    standard deviation of its most recent second of displayed data; both
    labels are coloured (reversed ``RdYlGn`` palette) and sized by that value.

    Controls:
        mouse wheel -- change the amplitude (y) scale.
        '+' / '-'   -- change the time (x) scale.
        'D'         -- toggle between band-pass filtered and raw data.
    """

    # Lower bounds for the two components of the ``u_scale`` uniform. Both
    # the wheel and the key handler use them, so neither one silently resets
    # the zoom set by the other.
    _MIN_SCALE_X = 1.
    _MIN_SCALE_Y = 0.01

    def __init__(self, lsl_inlet, scale=500, filt=True):
        """Build the GL program, text labels and filter state, then show.

        Parameters
        ----------
        lsl_inlet : pylsl.StreamInlet
            Open inlet that EEG chunks are pulled from.
        scale : float
            Divisor applied to the (filtered or raw) samples before plotting.
        filt : bool
            Whether to start with the band-pass filtered view.

        Raises
        ------
        ValueError
            If the stream reports a nominal sampling rate <= 0 (irregular
            rate), since the display window is sized from that rate.
        """
        app.Canvas.__init__(self, title='EEG - Use your wheel to zoom!',
                            keys='interactive')

        self.inlet = lsl_inlet
        info = self.inlet.info()
        description = info.desc()

        window = 10  # length of the displayed history, in seconds
        self.sfreq = info.nominal_srate()
        if self.sfreq <= 0:
            raise ValueError("EEG stream must have a positive nominal "
                             "sampling rate, got %r" % (self.sfreq,))
        n_samples = int(self.sfreq * window)
        self.n_chans = info.channel_count()
        ch_names = self._channel_labels(description, self.n_chans)

        # Layout: one column of n_chans stacked traces, n_samples each.
        n_rows = self.n_chans
        n_cols = 1
        n_signals = n_rows * n_cols

        # Traces start flat; on_timer overwrites a_position as data arrives.
        positions = np.zeros((n_signals, n_samples), dtype=np.float32)

        color = color_palette("RdBu_r", n_rows)
        color = np.repeat(color, n_samples, axis=0).astype(np.float32)

        # Per-vertex (column, row, sample-within-signal) index for the shader.
        index = np.c_[
            np.repeat(np.repeat(np.arange(n_cols), n_rows), n_samples),
            np.repeat(np.tile(np.arange(n_rows), n_cols), n_samples),
            np.tile(np.arange(n_samples), n_signals),
        ].astype(np.float32)

        self.program = gloo.Program(VERT_SHADER, FRAG_SHADER)
        self.program['a_position'] = positions.reshape(-1, 1)
        self.program['a_color'] = color
        self.program['a_index'] = index
        self.program['u_scale'] = (1., 1.)
        self.program['u_size'] = (n_rows, n_cols)
        self.program['u_n'] = n_samples

        # Text: channel name on the left, signal-quality value on the right.
        self.font_size = 48.
        self.names = [visuals.TextVisual(name, bold=True, color='white')
                      for name in ch_names]
        self.quality = [visuals.TextVisual('', bold=True, color='white')
                        for _ in range(self.n_chans)]
        self.quality_colors = color_palette("RdYlGn", 11)[::-1]

        self.scale = scale
        self.n_samples = n_samples
        self.filt = filt

        # Rolling buffers, shape (n_samples, n_chans): raw and filtered.
        self.data_f = np.zeros((n_samples, self.n_chans))
        self.data = np.zeros((n_samples, self.n_chans))

        # FIR band-pass (3-40 Hz). One filter-state column per channel so
        # successive chunks are filtered as one continuous signal.
        self.af = [1.0]
        self.bf = create_filter(self.data_f.T, self.sfreq, 3, 40.,
                                method='fir')
        zi = lfilter_zi(self.bf, self.af)
        self.filt_state = np.tile(zi, (self.n_chans, 1)).transpose()

        self._timer = app.Timer('auto', connect=self.on_timer, start=True)
        gloo.set_viewport(0, 0, *self.physical_size)
        gloo.set_state(clear_color='black', blend=True,
                       blend_func=('src_alpha', 'one_minus_src_alpha'))
        self.show()

    @staticmethod
    def _channel_labels(description, n_chans):
        """Return exactly ``n_chans`` labels from an LSL stream description.

        Reads the ``label`` value of the first ``n_chans`` children of the
        description's ``channels`` element, in order.
        """
        labels = []
        ch = description.child('channels').first_child()
        for _ in range(n_chans):
            labels.append(ch.child_value('label'))
            ch = ch.next_sibling()
        return labels

    @staticmethod
    def _quality_level(sd, n_levels=11):
        """Map per-channel std values to palette indices in [0, n_levels-1].

        A tanh curve centred on 30 with width 15 sends low std to index 0
        and high std to the top index. Non-finite values (e.g. NaN produced
        by NaN samples) map to the top index instead of an invalid one.
        """
        half = (n_levels - 1) / 2.
        level = np.tanh((np.asarray(sd, dtype=float) - 30) / 15) * half + half
        level = np.where(np.isfinite(level), level, n_levels - 1)
        return np.clip(level, 0, n_levels - 1).astype(np.int32)

    def on_key_press(self, event):
        """'D' toggles filtering; '+' / '-' change the time (x) scale."""
        # NOTE(review): some vispy backends report unmapped keys with
        # event.key set to None, so do not assume a .name attribute.
        key = getattr(event.key, 'name', None)

        if key == 'D':
            self.filt = not self.filt

        if key in ('+', '-'):
            dx = -0.05 if key == '+' else 0.05
            scale_x, scale_y = self.program['u_scale']
            # Only x changes; y is kept but clamped with the same floor the
            # wheel handler uses, so a wheel zoom below 1 survives.
            self.program['u_scale'] = (
                max(self._MIN_SCALE_X, scale_x * math.exp(dx)),
                max(self._MIN_SCALE_Y, scale_y))
            self.update()

    def on_mouse_wheel(self, event):
        """Scale the amplitude (y) axis by exp(+/-0.1) per wheel event."""
        dx = np.sign(event.delta[1]) * .05
        scale_x, scale_y = self.program['u_scale']
        self.program['u_scale'] = (
            max(self._MIN_SCALE_X, scale_x),
            max(self._MIN_SCALE_Y, scale_y * math.exp(2.0 * dx)))
        self.update()

    def on_timer(self, event):
        """Pull new samples, update the rolling buffers, labels and traces."""
        samples, timestamps = self.inlet.pull_chunk(timeout=0.0,
                                                    max_samples=100)
        if not timestamps:
            return

        # Channel order is reversed for plotting; sd is flipped back below
        # so that quality values line up with self.names.
        samples = np.array(samples)[:, ::-1]

        self.data = np.vstack([self.data, samples])[-self.n_samples:]

        filt_samples, self.filt_state = lfilter(self.bf, self.af, samples,
                                                axis=0, zi=self.filt_state)
        self.data_f = np.vstack([self.data_f, filt_samples])[-self.n_samples:]

        if self.filt:
            plot_data = self.data_f / self.scale
        else:
            plot_data = (self.data - self.data.mean(axis=0)) / self.scale

        # Std over the last sfreq samples, multiplied back by self.scale.
        sd = np.std(plot_data[-int(self.sfreq):], axis=0)[::-1] * self.scale
        co = self._quality_level(sd, len(self.quality_colors))

        for name, quality, value, level in zip(self.names, self.quality,
                                               sd, co):
            label_color = self.quality_colors[level]
            quality.text = '%.2f' % value
            quality.color = label_color
            quality.font_size = 12 + level
            name.font_size = 12 + level
            name.color = label_color

        self.program['a_position'].set_data(
            plot_data.T.ravel().astype(np.float32))
        self.update()

    def on_resize(self, event):
        """Match the GL viewport to the window and re-place every label."""
        vp = (0, 0, self.physical_size[0], self.physical_size[1])
        self.context.set_viewport(*vp)
        # Names sit near the left edge, quality values near the right edge,
        # one per channel row.
        for labels, x_frac in ((self.names, 0.025), (self.quality, 0.975)):
            for ii, t in enumerate(labels):
                t.transforms.configure(canvas=self, viewport=vp)
                t.pos = (self.size[0] * x_frac,
                         ((ii + 0.5) / self.n_chans) * self.size[1])

    def on_draw(self, event):
        """Clear, draw all traces as one line strip, then draw the labels."""
        gloo.clear()
        gloo.set_viewport(0, 0, *self.physical_size)
        self.program.draw('line_strip')
        for t in self.names + self.quality:
            t.draw()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement