Advertisement
Guest User

Untitled

a guest
Apr 20th, 2019
89
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 15.10 KB | None | 0 0
  1. # author: ngayraud.
  2. #
  3. # Created on Tue Feb 20 11:14:54 2018.
  4.  
  5. import numpy as np
  6.  
  7. from .source import simulate_sparse_stc, select_source_in_label
  8. from ..forward import apply_forward_raw
  9. from ..utils import warn, logger
  10. from ..io import RawArray
  11. from ..io.constants import FIFF
  12. from mne import find_events, pick_types
  13. from .waveforms import get_waveform
  14. from .noise import generate_noise_data
  15.  
  16.  
  17. class Simulation(dict):
  18. """Simulation of meg/eeg data.
  19.  
  20. Parameters
  21. ----------
  22. fwd : Forward
  23. a forward solution containing an instance of Info and src
  24. n_dipoles : int
  25. Number of dipoles to simulate.
  26. labels : None | list of Labels
  27. The labels. The default is None, otherwise its size must be n_dipoles.
  28. location : str
  29. The label location to choose a dipole from. Can be ``random`` (default)
  30. or ``center`` to use :func:`mne.Label.center_of_mass`. Note that for
  31. ``center`` mode the label values are used as weights.
  32. subject : string | None
  33. The subject the label is defined for.
  34. Only used with location=``center``.
  35. subjects_dir : str, or None
  36. Path to the SUBJECTS_DIR. If None, the path is obtained by using the
  37. environment variable SUBJECTS_DIR. Only used with location=``center``.
  38. waveform: list of callables/str of length n_dipoles | str | callable
  39. To simulate a waveform (activity) on each dipole. If it is a string or
  40. a callable, the same activity will be generated over all dipoles
  41. window_times : array | list | str
  42. time window(s) to generate activity. If list, its size should be
  43. len(waveform). If str, should be ``all`` (default)
  44.  
  45. Notes
  46. -----
  47. Some notes.
  48. """
  49.  
  50. def __init__(self, fwd, n_dipoles=2, labels=None, location='random',
  51. subject=None, subjects_dir=None, waveform='sin',
  52. window_times='all'):
  53. self.fwd = fwd # TODO: check fwd
  54. labels, n_dipoles = self._get_sources(labels, n_dipoles)
  55. self.update(n_dipoles=n_dipoles, labels=labels, subject=subject,
  56. subjects_dir=subjects_dir, location=location,
  57. info=self.fwd['info'], coefficients=None)
  58. self['info']['projs'] = []
  59. self['info']['bads'] = []
  60. self['info']['sfreq'] = None
  61. self.waveforms = self._get_waveform(waveform)
  62. self.window_times = self._get_window_times(window_times)
  63.  
  64. def _get_sources(self, labels, n_dipoles):
  65. """Get labels and number of dipoles. Can be upgraded to do more.
  66.  
  67. Return a list of labels or None and the number of dipoles.
  68. """
  69. if labels is None:
  70. return labels, n_dipoles
  71.  
  72. n_labels = min(n_dipoles, len(labels))
  73. if n_dipoles != len(labels):
  74. warn('The number of labels is different from the number of '
  75. 'dipoles. %s dipole(s) will be generated.'
  76. % n_labels)
  77. labels = labels[:n_labels]
  78. return labels, n_labels
  79.  
  80. def _get_waveform(self, waveform):
  81. """Check the waveform given as imput wrt the number of dipoles.
  82.  
  83. Return a list of callables.
  84. """
  85. if isinstance(waveform, str):
  86. return [get_waveform(waveform)]
  87.  
  88. elif isinstance(waveform, list):
  89.  
  90. if len(waveform) > self['n_dipoles']:
  91. warn('The number of waveforms is greater from the number of '
  92. 'dipoles. %s waveform(s) will be generated.'
  93. % self['n_dipoles'])
  94. waveform = waveform[:self['n_dipoles']]
  95.  
  96. elif len(waveform) < self['n_dipoles']:
  97. raise ValueError('Found fewer waveforms than dipoles.')
  98. return [get_waveform(f) for f in waveform]
  99.  
  100. else:
  101. raise TypeError('Unrecognised type. Accepted inputs: str, list, '
  102. 'callable, or list containing any of the above.')
  103.  
  104. def _check_window_time(self, w_t):
  105. """Check if window time has the correct value and frequency."""
  106. if isinstance(w_t, np.ndarray):
  107. freq = np.floor(1. / (w_t[-1] - w_t[-2]))
  108. _check_frequency(self['info'], freq, 'The frequency of the '
  109. 'time windows is not the same')
  110. elif w_t is 'all':
  111. pass
  112. else:
  113. raise TypeError('Unrecognised type. Accepted inputs: array, '
  114. '\'all\'')
  115. return w_t
  116.  
  117. def _get_window_times(self, window_times):
  118. """Get a list of window_times."""
  119. if isinstance(window_times, list):
  120.  
  121. if len(window_times) > len(self.waveforms):
  122. n_func = len(self.waveforms)
  123. warn('The number of window times is greater than the number '
  124. 'of waveforms. %s waveform(s) will be generated.'
  125. % n_func)
  126. window_times = window_times[:n_func]
  127.  
  128. elif len(window_times) < len(self.waveforms):
  129. pad = len(self.waveforms) - len(window_times)
  130. warn('The number of window times is smaller than the number '
  131. 'of waveforms. Assuming that the last ones are \'all\'')
  132. window_times = window_times + ['all'] * pad
  133. else:
  134. window_times = [window_times]
  135.  
  136. return [self._check_window_time(w_t) for w_t in window_times]
  137.  
  138.  
  139. def set_sources(self, cf=None):
  140. if cf is None:
  141. cf = []
  142. vertno = [[], []]
  143. for i, label in enumerate(self['labels']):
  144. lh_vertno, rh_vertno = select_source_in_label(
  145. self.fwd['src'], label, None, self['location'],
  146. self['subject'], self['subjects_dir'], 'sphere')
  147. vertno[0] += lh_vertno
  148. vertno[1] += rh_vertno
  149. if len(lh_vertno) == 0 and len(rh_vertno) == 0:
  150. raise ValueError('No vertno found.')
  151. cf.append((lh_vertno, rh_vertno))
  152. self.update(coefficients=cf)
  153. return self
  154.  
  155.  
  156. def get_events(sim, times, events):
  157. """Get a list of events.
  158.  
  159. Checks if the input events correspond to the simulation times.
  160.  
  161. Parameters
  162. ----------
  163. sim : instance of Simulation
  164. Initialized Simulation object with parameters
  165. times : array
  166. Time array
  167. events : array, | list of arrays | None
  168. events corresponding to some stimulation. If array, its size should be
  169. shape=(len(times), 3). If list, its size should be len(n_dipoles). If
  170. None, defaults to no event (default)
  171.  
  172. Returns
  173. -------
  174. events : list | None
  175. a list of events of type array, shape=(n_events, 3)
  176. """
  177. if isinstance(events, list):
  178. n_waveforms = len(sim.waveforms)
  179. if len(events) > n_waveforms:
  180. warn('The number of event arrays is greater than the number '
  181. 'of waveforms. %s event arrays(s) will be generated.'
  182. % n_waveforms)
  183. events = events[:n_waveforms]
  184. elif len(events) < n_waveforms:
  185. pad = len(sim.waveforms) - len(events)
  186. warn('The number of event arrays is smaller than the number '
  187. 'of waveforms. Assuming that the last ones are None')
  188. events = events + [None] * pad
  189. else:
  190. events = [events]
  191.  
  192. return [_check_event(event, times) for event in events]
  193.  
  194.  
  195. def simulate_raw_signal(sim, times, cov=None, events=None, random_state=None,
  196. variabilities=None, verbose=None):
  197. """Simulate a raw signal.
  198.  
  199. Parameters
  200. ----------
  201. sim : instance of Simulation
  202. Initialized Simulation object with parameters
  203. times : array
  204. Time array
  205. cov : Covariance | string | dict | None
  206. Covariance of the noise
  207. events : array, shape = (n_events, 3) | list of arrays | None
  208. events corresponding to some stimulation.
  209. If list, its size should be len(n_dipoles)
  210. If None, defaults to no event (default)
  211. random_state : None | int | np.random.RandomState
  212. To specify the random generator state.
  213. variabilities : None | list of dict
  214. Will add latency variabilities to the events
  215. verbose : bool, str, int, or None
  216. If not None, override default verbose level (see :func:`mne.verbose`
  217. and :ref:`Logging documentation <tut_logging>` for more).
  218.  
  219. Returns
  220. -------
  221. raw : instance of RawArray
  222. The simulated raw file.
  223. """
  224. if len(times) <= 2: # to ensure event encoding works
  225. raise ValueError('stc must have at least three time points')
  226.  
  227. info = sim['info'].copy()
  228. freq = np.floor(1. / (times[-1] - times[-2]))
  229. _check_frequency(info, freq, 'The frequency of the time windows is not '
  230. 'the same as the experiment time. ')
  231. info['projs'] = []
  232. info['lowpass'] = None
  233.  
  234. # TODO: generate data for blinks and other physiological noise
  235.  
  236. raw_data = np.zeros((len(info['ch_names']), len(times)))
  237.  
  238. events = get_events(sim, times, events)
  239.  
  240. logger.info('Simulating signal from %s sources' % sim['n_dipoles'])
  241.  
  242. for dipoles, labels, window_time, event, variability, data_fun, cf in \
  243. _iterate_simulation_sources(sim, events, variabilities, times):
  244.  
  245. source_data = simulate_sparse_stc(sim.fwd['src'], dipoles,
  246. window_time, data_fun, labels,
  247. None, sim['location'],
  248. sim['subject'], sim['subjects_dir'],
  249. 'sphere', cf)
  250. propagation = _get_propagation(event, times, window_time, variability,
  251. freq)
  252. source_data.data = np.dot(source_data.data, propagation)
  253. raw_data += apply_forward_raw(sim.fwd, source_data, info,
  254. verbose=verbose).get_data()
  255.  
  256. # Noise
  257. if cov is not None:
  258. raw_data += generate_noise_data(info, cov, len(times), random_state)[0]
  259.  
  260. # Add an empty stimulation channel
  261. raw_data = np.vstack((raw_data, np.zeros((1, len(times)))))
  262. stim_chan = dict(ch_name='STI 014', coil_type=FIFF.FIFFV_COIL_NONE,
  263. kind=FIFF.FIFFV_STIM_CH, logno=len(info["chs"]) + 1,
  264. scanno=len(info["chs"]) + 1, cal=1., range=1.,
  265. loc=np.full(12, np.nan), unit=FIFF.FIFF_UNIT_NONE,
  266. unit_mul=0., coord_frame=FIFF.FIFFV_COORD_UNKNOWN)
  267. info['chs'].append(stim_chan)
  268. info._update_redundant()
  269.  
  270. # Create RawArray object with all data
  271. raw = RawArray(raw_data, info, first_samp=times[0], verbose=verbose)
  272.  
  273. # Update the stimulation channel with stimulations
  274. stimulations = [event for event in events if event is not None]
  275. if len(stimulations) != 0:
  276. stimulations = np.unique(np.vstack(stimulations), axis=0)
  277. # Add events onto a stimulation channel
  278. raw.add_events(stimulations, stim_channel='STI 014')
  279.  
  280. logger.info('Done')
  281. return raw
  282.  
  283.  
  284. def _get_propagation(event, times, window_time, variability=None, sfreq=None):
  285. """Return the matrix that propagates the waveforms."""
  286. propagation = 1.0
  287. if event is not None:
  288. # generate stimulation timeline
  289. stimulation_timeline = np.zeros(len(times))
  290.  
  291. #account for latency variability
  292. if variability is not None:
  293. stimulation_indices = np.zeros_like(event[:, 0], dtype=int)
  294. for i in range(len(event)):
  295. if variability['latency']==0:
  296. latency = 0
  297. else:
  298. lim = int(variability['latency']*sfreq)
  299. latency = np.random.randint(-lim, lim)
  300. stimulation_indices[i] = int(event[i, 0])+latency
  301.  
  302. else:
  303. stimulation_indices = np.array(event[:, 0], dtype=int)
  304. stimulation_timeline[stimulation_indices] = event[:, 2]
  305.  
  306. from scipy.linalg import toeplitz
  307. # Create toeplitz array. Equivalent to convoluting the signal with the
  308. # stimulation timeline
  309. index = stimulation_timeline != 0
  310. trig = np.zeros((len(times)))
  311. trig[index] = 1
  312. propagation = toeplitz(trig[0:len(window_time)], trig)
  313.  
  314. #account for amplitude variability
  315. if variability is not None:
  316. peak = variability['peak']
  317. amplitudes = np.zeros((len(window_time), len(event)))
  318. for i in range(len(event)):
  319. amplitude = (variability['amplitude'] * np.random.randn())
  320. amplitude = np.ones_like(window_time) + (amplitude * \
  321. np.exp(-(window_time - peak)**2 / 0.02))
  322. amplitudes[:,i]=amplitude
  323. for i in range(len(amplitudes)):
  324. to_mult = np.argwhere(propagation[i]==1)[:,0]
  325. propagation[i,to_mult] = amplitudes[i,:len(to_mult)]
  326.  
  327. return propagation
  328.  
  329.  
  330. def _check_frequency(info, freq, error_message):
  331. """Compare two frequency values and assert they are the same."""
  332. if info['sfreq'] is not None:
  333. if info['sfreq'] != freq:
  334. raise ValueError(error_message)
  335. else:
  336. info['sfreq'] = freq
  337. return True
  338.  
  339.  
  340. def _correct_window_times(w_t, e_t, times):
  341. """Check if window time has the correct length."""
  342. if (isinstance(w_t, str) and w_t == 'all') or e_t is None:
  343. return times
  344. else:
  345. if len(w_t) > len(times):
  346. warn('Window is too large, will be cut to match the '
  347. 'length of parameter \'times\'')
  348. return w_t[:len(times)]
  349.  
  350.  
  351. def _iterate_simulation_sources(sim, events, variabilities, times):
  352. """Iterate over all stimulation waveforms."""
  353. variabilities = [variabilities] if variabilities is None else variabilities
  354. if len(sim.waveforms) == 1:
  355. yield (sim['n_dipoles'], sim['labels'],
  356. _correct_window_times(sim.window_times[0], events[0], times),
  357. events[0], variabilities[0], sim.waveforms[0],
  358. sim['coefficients'])
  359. else:
  360. dipoles = 1
  361. for index, waveform in enumerate(sim.waveforms):
  362. n_wt = min(index, len(sim.window_times) - 1)
  363. n_ev = min(index, len(events) - 1)
  364. n_var = min(index, len(variabilities) - 1)
  365. labels = None
  366. if sim['labels'] is not None:
  367. labels = [sim['labels'][index]]
  368. if sim['coefficients'] is not None:
  369. cf = sim['coefficients'][index]
  370. yield (dipoles, labels,
  371. _correct_window_times(sim.window_times[n_wt], events[n_ev],
  372. times), events[n_ev],
  373. variabilities[n_var], waveform, cf)
  374.  
  375.  
  376. def _check_event(event, times):
  377. """Check if event array has the correct shape/length."""
  378. if isinstance(event, np.ndarray) and event.shape[1] == 3:
  379. if np.max(event) > len(times) - 1:
  380. warn('The indices in the event array is not the same as '
  381. 'the time points in the simulations.')
  382. event[np.where(event > len(times) - 1), 0] = len(times) - 1
  383. return np.array(event)
  384. elif event is not None:
  385. warn('Urecognized type. Will generated signal without events.')
  386. return None
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement