Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from brian2 import *
- from brian2.equations.equations import (DIFFERENTIAL_EQUATION, SingleEquation,
- PARAMETER, SUBEXPRESSION)
def evaluate_rhs(eqs, values, namespace=None, level=0):
    """
    Evaluate the right-hand side of a system of differential equations for
    given state-variable values.

    External constants are resolved via ``namespace`` if one is given;
    otherwise they are taken from the namespace of the calling frame.

    This function can be used, for example, to find a resting state of the
    system, i.e. a fixed point where the RHS of all differential equations
    is approximately 0.

    Parameters
    ----------
    eqs : `Equations`
        The equations.
    values : dict-like
        Values for each of the state variables (differential equations and
        parameters), keyed by variable name.
    namespace : dict, optional
        Namespace used to resolve external identifiers (e.g. constants such
        as ``tau``). If ``None`` (the default), the local/global namespace of
        the calling frame is used.
    level : int, optional
        Number of additional stack frames to go back when looking up the
        implicit namespace (add 1 for each intermediate function between the
        frame whose namespace should be used and this function). Only used
        when ``namespace`` is ``None``.

    Returns
    -------
    rhs : dict
        One entry per differential equation. The keys are ``'RHS_'`` followed
        by the variable name (e.g. ``'RHS_v'`` for ``dv/dt = ...``), and the
        values are the respective right-hand sides evaluated at ``values``.
    """
    # Make a new set of equations, where each differential equation is
    # replaced by a parameter (so that its value can be set freely) and a
    # new subexpression 'RHS_<varname>' defines its right-hand side.
    # E.g. for 'dv/dt = -v / tau : volt' use:
    #     v : volt
    #     RHS_v = -v / tau : volt/second
    new_equations = []
    for eq in eqs.values():
        if eq.type == DIFFERENTIAL_EQUATION:
            new_equations.append(SingleEquation(PARAMETER, eq.varname,
                                                dimensions=eq.dim,
                                                var_type=eq.var_type))
            # The RHS of dx/dt has the dimensions of x per unit of time
            new_equations.append(SingleEquation(SUBEXPRESSION,
                                                'RHS_' + eq.varname,
                                                dimensions=eq.dim/second.dim,
                                                var_type=eq.var_type,
                                                expr=eq.expr))
        else:
            # All other equations (parameters, subexpressions) are kept as-is
            new_equations.append(eq)
    if namespace is None:
        # level+1 skips this function's own frame, so the caller's
        # namespace is used
        namespace = get_local_namespace(level+1)
    # TODO: Hide this from standalone mode
    # A single-neuron group on the numpy runtime target is enough to
    # evaluate the RHS subexpressions once
    group = NeuronGroup(1, model=Equations(new_equations),
                        codeobj_class=NumpyCodeObject,
                        namespace=namespace)
    # Set the values of the state variables/parameters
    group.set_states(values)
    # Read back the values of all RHS_... subexpressions
    rhs_names = ['RHS_' + name for name in eqs.diff_eq_names]
    return group.get_states(rhs_names)
if __name__ == '__main__':
    # NOTE: this script intentionally stays at module level: evaluate_rhs()
    # and run() resolve the model constants below (gl, El, VT, ...) from the
    # caller's namespace, which here is the module's global namespace.

    # Parameters
    area = 20000 * umetre ** 2
    Cm = 1 * ufarad * cm ** -2 * area
    gl = 5e-5 * siemens * cm ** -2 * area
    El = -65 * mV
    EK = -90 * mV
    ENa = 50 * mV
    g_na = 100 * msiemens * cm ** -2 * area
    g_kd = 30 * msiemens * cm ** -2 * area
    VT = -63 * mV
    I = 0.01*nA
    eqs = Equations('''
    dv/dt = (gl*(El-v) - g_na*(m*m*m)*h*(v-ENa) - g_kd*(n*n*n*n)*(v-EK) + I)/Cm : volt
    dm/dt = 0.32*(mV**-1)*(13.*mV-v+VT)/
    (exp((13.*mV-v+VT)/(4.*mV))-1.)/ms*(1-m)-0.28*(mV**-1)*(v-VT-40.*mV)/
    (exp((v-VT-40.*mV)/(5.*mV))-1.)/ms*m : 1
    dn/dt = 0.032*(mV**-1)*(15.*mV-v+VT)/
    (exp((15.*mV-v+VT)/(5.*mV))-1.)/ms*(1.-n)-.5*exp((10.*mV-v+VT)/(40.*mV))/ms*n : 1
    dh/dt = 0.128*exp((17.*mV-v+VT)/(18.*mV))/ms*(1.-h)-4./(1+exp((40.*mV-v+VT)/(5.*mV)))/ms*h : 1
    ''')

    # Shared settings: the root finder and the simulation start from the
    # same point (gating variables m, n, h start at 0 in both cases).
    state_vars = ['v', 'm', 'n', 'h']
    v_init = -70*mV
    duration = 200*ms

    # Find the resting state of this model, i.e. the point where the RHS of
    # all differential equations vanishes.
    def wrapper(args):
        """Evaluate the RHS at ``args`` = [v (in volt), m, n, h] as floats."""
        rhs = evaluate_rhs(eqs, {'v': args[0]*volt,
                                 'm': args[1],
                                 'n': args[2],
                                 'h': args[3]})
        return [float(rhs['RHS_' + varname]) for varname in state_vars]

    from scipy.optimize import root
    result = root(wrapper, x0=np.array([float(v_init), 0, 0, 0]))
    if not result.success:
        # Plotting a non-converged point as "resting state" would mislead
        raise RuntimeError('Root finding for the resting state failed: '
                           '{}'.format(result.message))

    # Simulate neuron and compare resting state to calculated resting state
    group = NeuronGroup(1, eqs, method='exponential_euler')
    group.v = v_init
    mon = StateMonitor(group, state_vars, record=0)
    run(duration)

    fig, axes = plt.subplots(2, 2, sharex='all')
    t_end = float(duration/ms)
    # axes.flat iterates row by row: [0, 0], [0, 1], [1, 0], [1, 1]
    for idx, (ax, varname) in enumerate(zip(axes.flat, state_vars)):
        simulated = getattr(mon[0], varname)
        resting = result.x[idx]
        if varname == 'v':
            # The root finder works in SI units (volt); plot in mV
            simulated = simulated/mV
            resting = float(resting*volt/mV)
        ax.plot(mon.t/ms, simulated, label='simulation')
        ax.plot(t_end, resting, 'rx', label='resting state')
        ax.set(ylabel=varname, xlabel='time (ms)')
    # One legend is enough: all subplots use the same two labels
    axes[0, 0].legend()
    plt.tight_layout()
    plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement