Guest User

Untitled

a guest
Dec 10th, 2018
258
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C 20.08 KB | None | 0 0
  1. #pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable
  2. #pragma OPENCL EXTENSION cl_khr_local_int32_base_atomics : enable
  3. #pragma OPENCL EXTENSION cl_amd_printf : enable
  4.  
  5. uint TausStep(uint* z, int S1, int S2, int S3, uint M)
  6. {
  7.     uint b = (((*z << S1) ^ *z) >> S2);
  8.     return *z = (((*z & M) << S3) ^ b);
  9. }
  10.  
  11. uint LCGStep(uint* z, uint A, uint C)
  12. {
  13.     return *z = (A * *z + C);
  14. }
  15.  
  16. float rand2(uint2* zp)
  17. {
  18.     uint z0 = (*zp).s0;
  19.     uint z1 = (*zp).s1;
  20.     float ret = 2.3283064365387e-10 * (TausStep(&z0, 13, 19, 12, 4294967294UL)
  21.                                        ^ LCGStep(&z1, 1664525, 1013904223UL));
  22.     (*zp).s0 = z0;
  23.     (*zp).s1 = z1;
  24.     return ret;
  25. }
  26.  
  27. #undef record
  28. #define record(flags, data) \
  29. if(_record_flags & (flags)) \
  30. { \
  31.     int idx = atomic_inc(&_local_record_idx); \
  32.     if(idx >= _record_buffer_size) \
  33.     { \
  34.         _error_buffer[i + 1] = 1; \
  35.         _local_record_idx = _record_buffer_size - 1; \
  36.     } \
  37.     float4 record; \
  38.     record.s0 = t; \
  39.     record.s1 = (data); \
  40.     record.s2 = flags; \
  41.     record.s3 = i; \
  42.     _record_buffer[idx + _local_record_idx_start] = record; \
  43. } \
  44.  
  45.  
  46. __kernel void RS_step
  47. (
  48.     const float _t,
  49.     __global int* _error_buffer,
  50.     __global int* _record_flags_buffer,
  51.     __global int* _record_idx,
  52.     __global int* _record_idx_start,
  53.     __global float4* _record_buffer,
  54.     const int _record_buffer_size,
  55.     __global float* _rwvalues_0_buf,
  56.     __global float* _rwvalues_1_buf,
  57.     __global float* _rwvalues_2_buf,
  58.     __global float* _rwvalues_3_buf,
  59.     __global float* _rwvalues_4_buf,
  60.     __global float* _rwvalues_5_buf,
  61.     __global float* _rwvalues_6_buf,
  62.     __global float* _rwvalues_7_buf,
  63.     __global float* _rwvalues_8_buf,
  64.     __global float* _rwvalues_9_buf,
  65.     __global float* _rwvalues_10_buf,
  66.  
  67.     __global float* _rovalues_0_buf,
  68.     __global float* _rovalues_1_buf,
  69.     __global float* _rovalues_2_buf,
  70.     __global float* _rovalues_3_buf,
  71.     __global float* _rovalues_4_buf,
  72.     __global float* _rovalues_5_buf,
  73.     __global float* _rovalues_6_buf,
  74.     __global float* _rovalues_7_buf,
  75.     __global float* _rovalues_8_buf,
  76.     __global float* _rovalues_9_buf,
  77.     __global float* _rovalues_10_buf,
  78.     __global float* _rovalues_11_buf,
  79.     __global float* _rovalues_12_buf,
  80.     __global float* _rovalues_13_buf,
  81.     __global float* _rovalues_14_buf,
  82.  
  83.     const float stdp_Ap,
  84.     const float stdp_Am,
  85.     const float stdp_o1_tau,
  86.     const float stdp_o2_tau,
  87.     const float stdp_r1_tau,
  88.     const float stdp_r2_tau,
  89.     __global uint2* _rand_state_buf,
  90.     __global float* _dt_buf,
  91.     __global int* _circ_buffer_start,
  92.     __global int* _circ_buffer_end,
  93.     __global float* _circ_buffer,
  94.     __global int* _fired_syn_idx_buffer,
  95.     __global int* _fired_syn_buffer,
  96.     __global float* _static_excitatory_on_buf,
  97.     __global float* _static_excitatory_weight_buf,
  98.     __global float* _static_inhibitory_weight_buf,
  99.     __global float* _stdp_weight_buf,
  100.     __global float* _stdp_r1_buf,
  101.     __global float* _stdp_r2_buf,
  102.     __global float* _stdp_last_r_time_buf,
  103.     __local int* _syn_thresh_0_arr,
  104.     const int count
  105. )
  106. {
  107.     int i = get_global_id(0);
  108.     int _local_id = get_local_id(0);
  109.     int _group_id = get_group_id(0);
  110.     const float V_tol = 2.000000e-01;
  111.     const float u_tol = 2.000000e-02;
  112.     const float axon_delay = 1.000000e+00;
  113.     const float rand_syn_r_tol = 1.000000e-02;
  114.     const float static_excitatory_E = 0.000000e+00;
  115.     const float static_excitatory_tau1 = 2.000000e+00;
  116.     const float static_excitatory_tau2 = 5.000000e+01;
  117.     const float static_excitatory_N = 1.000000e+01;
  118.     const float static_excitatory_P = 2.500000e-01;
  119.     const float static_inhibitory_E = -8.000000e+01;
  120.     const float static_inhibitory_gsyn = 2.000000e+00;
  121.     const float static_inhibitory_tau = 1.000000e+01;
  122.     const float static_inhibitory_N = 1.000000e+01;
  123.     const float static_inhibitory_P = 2.500000e-01;
  124.     const float stdp_s_tol = 1.000000e-02;
  125.     const float stdp_E = -8.000000e+01;
  126.     const float stdp_gsyn = 2.000000e-01;
  127.     const float stdp_tau = 1.000000e+01;
  128.     const float stdp_N = 1.000000e+01;
  129.     const float stdp_P = 2.500000e-01;
  130.     const float static_excitatory_s1_tol = 0.1;
  131.     const float static_excitatory_s2_tol = 0.1;
  132.     const float static_inhibitory_s_tol = 0.1;
  133.     __local int _syn_thresh_0_num;
  134.     int _local_size = get_local_size(0);
  135.     __local int _num_complete;
  136.  
  137.     if(_local_id == 0)
  138.         _num_complete = 0;
  139.  
  140.     float _cur_time = 0;
  141.     const float timestep = 1;
  142.     int _record_flags = _record_flags_buffer[i];
  143.     __local int _local_record_idx;
  144.     __local int _local_record_idx_start;
  145.  
  146.     if(_local_id == 0)
  147.     {
  148.         _local_record_idx = _record_idx[_group_id];
  149.         _local_record_idx_start = _record_idx_start[_group_id];
  150.     }
  151.  
  152.     float _dt;
  153.     float t = _t;
  154.     float _dt_residual = 0;
  155.     _dt = _dt_buf[i];
  156.     float _rwvalues_0 = _rwvalues_0_buf[i];
  157.     float V = _rwvalues_0;
  158.     float _rwvalues_1 = _rwvalues_1_buf[i];
  159.     float u = _rwvalues_1;
  160.     float _rwvalues_2 = _rwvalues_2_buf[i];
  161.     float rand_syn_r = _rwvalues_2;
  162.     float _rwvalues_3 = _rwvalues_3_buf[i];
  163.     float static_excitatory_s1 = _rwvalues_3;
  164.     float _rwvalues_4 = _rwvalues_4_buf[i];
  165.     float static_excitatory_s2 = _rwvalues_4;
  166.     float _rwvalues_5 = _rwvalues_5_buf[i];
  167.     float static_inhibitory_s = _rwvalues_5;
  168.     float _rwvalues_6 = _rwvalues_6_buf[i];
  169.     float stdp_s = _rwvalues_6;
  170.     float _rwvalues_7 = _rwvalues_7_buf[i];
  171.     float stdp_o1 = _rwvalues_7;
  172.     float _rwvalues_8 = _rwvalues_8_buf[i];
  173.     float stdp_o2 = _rwvalues_8;
  174.     float _rwvalues_9 = _rwvalues_9_buf[i];
  175.     float stdp_last_o_time = _rwvalues_9;
  176.     float _rwvalues_10 = _rwvalues_10_buf[i];
  177.     float stdp_active = _rwvalues_10;
  178.     float _rovalues_0 = _rovalues_0_buf[i];
  179.     float a = _rovalues_0;
  180.     float _rovalues_1 = _rovalues_1_buf[i];
  181.     float b = _rovalues_1;
  182.     float _rovalues_2 = _rovalues_2_buf[i];
  183.     float c = _rovalues_2;
  184.     float _rovalues_3 = _rovalues_3_buf[i];
  185.     float d = _rovalues_3;
  186.     float _rovalues_4 = _rovalues_4_buf[i];
  187.     float start_amp = _rovalues_4;
  188.     float _rovalues_5 = _rovalues_5_buf[i];
  189.     float end_amp = _rovalues_5;
  190.     float _rovalues_6 = _rovalues_6_buf[i];
  191.     float start_t = _rovalues_6;
  192.     float _rovalues_7 = _rovalues_7_buf[i];
  193.     float end_t = _rovalues_7;
  194.     float _rovalues_8 = _rovalues_8_buf[i];
  195.     float rand_syn_tau = _rovalues_8;
  196.     float _rovalues_9 = _rovalues_9_buf[i];
  197.     float rand_syn_weight = _rovalues_9;
  198.     float _rovalues_10 = _rovalues_10_buf[i];
  199.     float rand_syn_rate = _rovalues_10;
  200.     float _rovalues_11 = _rovalues_11_buf[i];
  201.     float rand_syn_frequency = _rovalues_11;
  202.     float _rovalues_12 = _rovalues_12_buf[i];
  203.     float rand_syn_sine_amp = _rovalues_12;
  204.     float _rovalues_13 = _rovalues_13_buf[i];
  205.     float static_excitatory_gsyn1 = _rovalues_13;
  206.     float _rovalues_14 = _rovalues_14_buf[i];
  207.     float static_excitatory_gsyn2 = _rovalues_14;
  208.     uint2 _rand_state = _rand_state_buf[i];
  209.     const int _syn_offset = 0 + i * 1250;
  210.     int _syn_table_end = _fired_syn_idx_buffer[i + 0];
  211.  
  212.     if(_syn_table_end != _syn_offset)
  213.     {
  214.         for(int _syn_table_idx = _syn_offset; _syn_table_idx < _syn_table_end; _syn_table_idx++)
  215.         {
  216.             int syn_i = _fired_syn_buffer[_syn_table_idx];
  217.  
  218.             if(syn_i < 700)
  219.             {
  220.                 int _g_syn_i = syn_i - 0 + i * 700;
  221.                 /* Load values */
  222.                 float static_excitatory_on = _static_excitatory_on_buf[_g_syn_i];
  223.                 float static_excitatory_weight = _static_excitatory_weight_buf[_g_syn_i];
  224.                 /* Syn code */
  225.                 float frac = (sqrt(-2 * log(rand2(&_rand_state) + 0.000001)) * cospi(2 * rand2(&_rand_state))) * sqrt(static_excitatory_N * static_excitatory_P * (1.0 - static_excitatory_P)) + static_excitatory_N * static_excitatory_P;
  226.                 frac = fmax(frac, 0.0f);
  227.                 static_excitatory_s1 += static_excitatory_on * frac * static_excitatory_gsyn1 * static_excitatory_weight;
  228.                 static_excitatory_s2 += static_excitatory_on * frac * static_excitatory_gsyn2 * static_excitatory_weight;
  229.                 //record(2, static_excitatory_weight, 100+syn_i); //flags, val, tag
  230.                 /* Save values */
  231.             }
  232.             else if(syn_i < 750)
  233.             {
  234.                 int _g_syn_i = syn_i - 700 + i * 50;
  235.                 /* Load values */
  236.                 float static_inhibitory_weight = _static_inhibitory_weight_buf[_g_syn_i];
  237.                 /* Syn code */
  238.                 float frac = (sqrt(-2 * log(rand2(&_rand_state) + 0.000001)) * cospi(2 * rand2(&_rand_state))) * sqrt(static_inhibitory_N * static_inhibitory_P * (1.0 - static_inhibitory_P)) + static_inhibitory_N * static_inhibitory_P;
  239.                 frac = fmax(frac, 0.0f);
  240.                 static_inhibitory_s += frac * static_inhibitory_gsyn * static_inhibitory_weight;
  241.                 //record(2, static_inhibitory_weight, 100+syn_i); //flags, val, tag
  242.                 /* Save values */
  243.             }
  244.             else if(syn_i < 1250)
  245.             {
  246.                 int _g_syn_i = syn_i - 750 + i * 500;
  247.                 /* Load values */
  248.                 float stdp_weight = _stdp_weight_buf[_g_syn_i];
  249.                 float stdp_r1 = _stdp_r1_buf[_g_syn_i];
  250.                 float stdp_r2 = _stdp_r2_buf[_g_syn_i];
  251.                 float stdp_last_r_time = _stdp_last_r_time_buf[_g_syn_i];
  252.                 /* Syn code */
  253.                 stdp_o1 = stdp_o1 * exp(-(t - stdp_last_o_time) / stdp_o1_tau);
  254.                 stdp_o2 = stdp_o2 * exp(-(t - stdp_last_o_time) / stdp_o2_tau);
  255.                 stdp_last_o_time = t;
  256.  
  257.                 if(stdp_active > 0)
  258.                     stdp_weight -= stdp_Am * (stdp_o1 - stdp_o2);
  259.  
  260.                 stdp_weight = fmax(stdp_weight, 0.0f);
  261.                 float frac = (sqrt(-2 * log(rand2(&_rand_state) + 0.000001)) * cospi(2 * rand2(&_rand_state))) * sqrt(stdp_N * stdp_P * (1.0 - stdp_P)) + stdp_N * stdp_P;
  262.                 frac = fmax(frac, 0.0f);
  263.                 stdp_s += frac * stdp_gsyn * stdp_weight;
  264.                 stdp_r1 = stdp_r1 * exp(-(t - stdp_last_r_time) / stdp_r1_tau) + 1;
  265.                 stdp_r2 = stdp_r2 * exp(-(t - stdp_last_r_time) / stdp_r2_tau) + 1;
  266.                 stdp_last_r_time = t;
  267.                 /* Save values */
  268.                 _stdp_weight_buf[_g_syn_i] = stdp_weight;
  269.                 _stdp_r1_buf[_g_syn_i] = stdp_r1;
  270.                 _stdp_r2_buf[_g_syn_i] = stdp_r2;
  271.                 _stdp_last_r_time_buf[_g_syn_i] = stdp_last_r_time;
  272.             }
  273.         }
  274.  
  275.         _dt = 0.1f;
  276.         _fired_syn_idx_buffer[i + 0] = _syn_offset;
  277.     }
  278.  
  279.     {
  280.         if(rand_syn_weight > 0)
  281.         {
  282.             if(rand2(&_rand_state) < timestep * (rand_syn_rate + rand_syn_sine_amp * sin(2.0 * M_PI_F * t / 1000.0 * rand_syn_frequency)) / 1000.0)
  283.                 rand_syn_r += rand_syn_weight;
  284.  
  285.             _dt = 0.1f;
  286.         }
  287.     }
  288.  
  289.     barrier(CLK_LOCAL_MEM_FENCE);
  290.  
  291.     while(_num_complete < _local_size)
  292.     {
  293.         bool _any_thresh = false;
  294.         /* Threshold statuses */
  295.         bool thresh_0_state = false;
  296.         bool thresh_1_state = false;
  297.  
  298.         if(_local_id == 0)
  299.         {
  300.             _syn_thresh_0_num = 0;
  301.         }
  302.  
  303.         barrier(CLK_LOCAL_MEM_FENCE);
  304.  
  305.         while(!_any_thresh && _cur_time < timestep)
  306.         {
  307.             t = _t + _cur_time;
  308.  
  309.             /* Post-thresh integrator code */
  310.             /* Clamp the _dt not too overshoot the timestep */
  311.             if(_cur_time < timestep && _cur_time + _dt >= timestep)
  312.             {
  313.                 _dt_residual = _dt;
  314.                 _dt = timestep - _cur_time + 0.0001f;
  315.                 _dt_residual -= _dt;
  316.             }
  317.  
  318.             float _error = 0;
  319.             /* See where the thresholded states are before changing them (doesn't work for synapse states)*/
  320.             bool thresh_0_pre_state = V > 0;
  321.             bool thresh_1_pre_state = V > 0;
  322.             bool syn_thresh_0_pre_state = V > 0;
  323.             /* Declare local variables */
  324.             float I;
  325.             /* Pre-stage code */
  326.             {
  327.                 record(1, V); //flags, val, tag
  328.             }
  329.             {
  330.                 if(t == start_t)
  331.                     _dt = 0.1f;
  332.                 else if(t == end_t)
  333.                     _dt = 0.1f;
  334.             }
  335.             /* Integrator code */
  336.             /* Declare storage for first state estimate*/
  337.             float _V_0 = V;
  338.             float _u_0 = u;
  339.             float _rand_syn_r_0 = rand_syn_r;
  340.             float _static_excitatory_s1_0 = static_excitatory_s1;
  341.             float _static_excitatory_s2_0 = static_excitatory_s2;
  342.             float _static_inhibitory_s_0 = static_inhibitory_s;
  343.             float _stdp_s_0 = stdp_s;
  344.             /* First derivative stage */
  345.             float _dV_dt_1;
  346.             float _du_dt_1;
  347.             float _drand_syn_r_dt_1;
  348.             float _dstatic_excitatory_s1_dt_1;
  349.             float _dstatic_excitatory_s2_dt_1;
  350.             float _dstatic_inhibitory_s_dt_1;
  351.             float _dstdp_s_dt_1;
  352.             /* Second derivative stage */
  353.             float _dV_dt_2;
  354.             float _du_dt_2;
  355.             float _drand_syn_r_dt_2;
  356.             float _dstatic_excitatory_s1_dt_2;
  357.             float _dstatic_excitatory_s2_dt_2;
  358.             float _dstatic_inhibitory_s_dt_2;
  359.             float _dstdp_s_dt_2;
  360.             /* Compute the first derivatives */
  361.             {
  362.                 I = 0;
  363.             }
  364.             {
  365.                 if(t >= start_t && t < end_t)
  366.                 {
  367.                     I += (t - start_t) / (end_t - start_t) * (end_amp - start_amp) + start_amp;
  368.                 }
  369.             }
  370.             {
  371.                 I += rand_syn_r * (0 - V);
  372.             }
  373.             {
  374.                 I += (static_excitatory_s1 + static_excitatory_s2 * 1.0f / (1.0f + 1.0f / 3.57 * exp(-0.062 * V))) * (static_excitatory_E - V);
  375.             }
  376.             {
  377.                 I += static_inhibitory_s * (static_inhibitory_E - V);
  378.             }
  379.             {
  380.                 I += stdp_s * (stdp_E - V);
  381.             }
  382.             {
  383.                 _dV_dt_1 = (0.04f * V + 5) * V + 140 - u + I;
  384.                 _du_dt_1 = a * (b * V - u);
  385.             }
  386.             {
  387.                 _drand_syn_r_dt_1 = -rand_syn_r / rand_syn_tau;
  388.             }
  389.             {
  390.                 _dstatic_excitatory_s1_dt_1 = -static_excitatory_s1 / static_excitatory_tau1;
  391.                 _dstatic_excitatory_s2_dt_1 = -static_excitatory_s2 / static_excitatory_tau2;
  392.             }
  393.             {
  394.                 _dstatic_inhibitory_s_dt_1 = -static_inhibitory_s / static_inhibitory_tau;
  395.             }
  396.             {
  397.                 _dstdp_s_dt_1 = -stdp_s / stdp_tau;
  398.             }
  399.             /* Compute the first state estimate */
  400.             _dV_dt_1 *= _dt;
  401.             _V_0 += _dV_dt_1;
  402.             _du_dt_1 *= _dt;
  403.             _u_0 += _du_dt_1;
  404.             _drand_syn_r_dt_1 *= _dt;
  405.             _rand_syn_r_0 += _drand_syn_r_dt_1;
  406.             _dstatic_excitatory_s1_dt_1 *= _dt;
  407.             _static_excitatory_s1_0 += _dstatic_excitatory_s1_dt_1;
  408.             _dstatic_excitatory_s2_dt_1 *= _dt;
  409.             _static_excitatory_s2_0 += _dstatic_excitatory_s2_dt_1;
  410.             _dstatic_inhibitory_s_dt_1 *= _dt;
  411.             _static_inhibitory_s_0 += _dstatic_inhibitory_s_dt_1;
  412.             _dstdp_s_dt_1 *= _dt;
  413.             _stdp_s_0 += _dstdp_s_dt_1;
  414.             /* Compute the derivatives again */
  415.             {
  416.                 I = 0;
  417.             }
  418.             {
  419.                 if(t >= start_t && t < end_t)
  420.                 {
  421.                     I += (t - start_t) / (end_t - start_t) * (end_amp - start_amp) + start_amp;
  422.                 }
  423.             }
  424.             {
  425.                 I += _rand_syn_r_0 * (0 - _V_0);
  426.             }
  427.             {
  428.                 I += (_static_excitatory_s1_0 + _static_excitatory_s2_0 * 1.0f / (1.0f + 1.0f / 3.57 * exp(-0.062 * _V_0))) * (static_excitatory_E - _V_0);
  429.             }
  430.             {
  431.                 I += _static_inhibitory_s_0 * (static_inhibitory_E - _V_0);
  432.             }
  433.             {
  434.                 I += _stdp_s_0 * (stdp_E - _V_0);
  435.             }
  436.             {
  437.                 _dV_dt_2 = (0.04f * _V_0 + 5) * _V_0 + 140 - _u_0 + I;
  438.                 _du_dt_2 = a * (b * _V_0 - _u_0);
  439.             }
  440.             {
  441.                 _drand_syn_r_dt_2 = -_rand_syn_r_0 / rand_syn_tau;
  442.             }
  443.             {
  444.                 _dstatic_excitatory_s1_dt_2 = -_static_excitatory_s1_0 / static_excitatory_tau1;
  445.                 _dstatic_excitatory_s2_dt_2 = -_static_excitatory_s2_0 / static_excitatory_tau2;
  446.             }
  447.             {
  448.                 _dstatic_inhibitory_s_dt_2 = -_static_inhibitory_s_0 / static_inhibitory_tau;
  449.             }
  450.             {
  451.                 _dstdp_s_dt_2 = -_stdp_s_0 / stdp_tau;
  452.             }
  453.             /* Compute the final change in state */
  454.             _dV_dt_2 = (_dV_dt_1 + _dt * _dV_dt_2) / 2;
  455.             _du_dt_2 = (_du_dt_1 + _dt * _du_dt_2) / 2;
  456.             _drand_syn_r_dt_2 = (_drand_syn_r_dt_1 + _dt * _drand_syn_r_dt_2) / 2;
  457.             _dstatic_excitatory_s1_dt_2 = (_dstatic_excitatory_s1_dt_1 + _dt * _dstatic_excitatory_s1_dt_2) / 2;
  458.             _dstatic_excitatory_s2_dt_2 = (_dstatic_excitatory_s2_dt_1 + _dt * _dstatic_excitatory_s2_dt_2) / 2;
  459.             _dstatic_inhibitory_s_dt_2 = (_dstatic_inhibitory_s_dt_1 + _dt * _dstatic_inhibitory_s_dt_2) / 2;
  460.             _dstdp_s_dt_2 = (_dstdp_s_dt_1 + _dt * _dstdp_s_dt_2) / 2;
  461.             /* Compute the error in this step */
  462.             _error = max(_error, fabs(_dV_dt_1 - _dV_dt_2) / V_tol);
  463.             _error = max(_error, fabs(_du_dt_1 - _du_dt_2) / u_tol);
  464.             _error = max(_error, fabs(_drand_syn_r_dt_1 - _drand_syn_r_dt_2) / rand_syn_r_tol);
  465.             _error = max(_error, fabs(_dstatic_excitatory_s1_dt_1 - _dstatic_excitatory_s1_dt_2) / static_excitatory_s1_tol);
  466.             _error = max(_error, fabs(_dstatic_excitatory_s2_dt_1 - _dstatic_excitatory_s2_dt_2) / static_excitatory_s2_tol);
  467.             _error = max(_error, fabs(_dstatic_inhibitory_s_dt_1 - _dstatic_inhibitory_s_dt_2) / static_inhibitory_s_tol);
  468.             _error = max(_error, fabs(_dstdp_s_dt_1 - _dstdp_s_dt_2) / stdp_s_tol);
  469.             /* Update state */
  470.             V += _dV_dt_2;
  471.             u += _du_dt_2;
  472.             rand_syn_r += _drand_syn_r_dt_2;
  473.             static_excitatory_s1 += _dstatic_excitatory_s1_dt_2;
  474.             static_excitatory_s2 += _dstatic_excitatory_s2_dt_2;
  475.             static_inhibitory_s += _dstatic_inhibitory_s_dt_2;
  476.             stdp_s += _dstdp_s_dt_2;
  477.             /* Advance and compute the new step size*/
  478.             _cur_time += _dt;
  479.  
  480.             if(_error == 0)
  481.                 _dt = timestep;
  482.             else
  483.             {
  484.                 /* Approximate the cube root using Halley's Method (error is usually between 0 and 10)*/
  485.                 float cr = (1.0f + 2 * _error) / (2.0f + _error);
  486.                 float cr3 = cr * cr * cr;
  487.                 cr = cr * (cr3 + 2.0f * _error) / (2.0f * cr3 + _error);
  488.                 _dt *= 0.9f / cr;
  489.                 /* _dt *= 0.9f * rootn(_error, -3.0f); */
  490.             }
  491.  
  492.             /* Detect thresholds */
  493.             thresh_0_state = !thresh_0_pre_state && (V > 0);
  494.             _any_thresh |= thresh_0_state;
  495.             thresh_1_state = !thresh_1_pre_state && (V > 0);
  496.             _any_thresh |= thresh_1_state;
  497.  
  498.             if(!syn_thresh_0_pre_state && (V > 0))
  499.             {
  500.                 _any_thresh = true;
  501.                 _syn_thresh_0_arr[atomic_inc(&_syn_thresh_0_num)] = i;
  502.             }
  503.  
  504.             /* Check exit condition */
  505.             if(_cur_time >= timestep)
  506.                 atomic_inc(&_num_complete);
  507.         }
  508.  
  509.         /* Handle thresholds */
  510.         if(thresh_0_state)
  511.         {
  512.             float delay = 1.0f;
  513.             V = c;
  514.             u += d;
  515.             delay = axon_delay;
  516.             record(4, 0);
  517.             _dt = 0.1f;
  518.             int _idx_idx = 1 * i + 0;
  519.             int _buff_start = _circ_buffer_start[_idx_idx];
  520.  
  521.             if(_buff_start != _circ_buffer_end[_idx_idx])
  522.             {
  523.                 const int _circ_buffer_size = 150;
  524.                 int _end_idx;
  525.  
  526.                 if(_buff_start < 0) //It is empty
  527.                 {
  528.                     _circ_buffer_start[_idx_idx] = 0;
  529.                     _circ_buffer_end[_idx_idx] = 1;
  530.                     _end_idx = 1;
  531.                 }
  532.                 else
  533.                 {
  534.                     _end_idx = _circ_buffer_end[_idx_idx] = (_circ_buffer_end[_idx_idx] + 1) % _circ_buffer_size;
  535.                 }
  536.  
  537.                 int _buff_idx = (i * 1 + 0) * _circ_buffer_size + _end_idx - 1;
  538.                 _circ_buffer[_buff_idx] = t + delay;
  539.             }
  540.             else //It is full, error
  541.             {
  542.                 _error_buffer[i + 1] = 2 + 0;
  543.             }
  544.         }
  545.  
  546.         if(thresh_1_state)
  547.         {
  548.             stdp_o1 = stdp_o1 * exp(-(t - stdp_last_o_time) / stdp_o1_tau) + 1;
  549.             stdp_o2 = stdp_o2 * exp(-(t - stdp_last_o_time) / stdp_o2_tau) + 1;
  550.             stdp_last_o_time = t;
  551.         }
  552.  
  553.         barrier(CLK_LOCAL_MEM_FENCE);
  554.  
  555.         if(_syn_thresh_0_num > 0)
  556.         {
  557.             for(int _ii = 0; _ii < _syn_thresh_0_num; _ii++)
  558.             {
  559.                 int nrn_id = _syn_thresh_0_arr[_ii];
  560.                 /* Declare locals */
  561.                 __local float stdp_s_local;
  562.                 __local float stdp_o1_local;
  563.                 __local float stdp_o2_local;
  564.                 __local float stdp_active_local;
  565.  
  566.                 /* Init locals */
  567.                 if(i == nrn_id)
  568.                 {
  569.                     stdp_s_local = stdp_s;
  570.                     stdp_o1_local = stdp_o1;
  571.                     stdp_o2_local = stdp_o2;
  572.                     stdp_active_local = stdp_active;
  573.                 }
  574.  
  575.                 barrier(CLK_LOCAL_MEM_FENCE);
  576.                 int _syn_offset = nrn_id * 500;
  577.  
  578.                 for(int _g_syn_i = _syn_offset + _local_id; _g_syn_i < 500 + _syn_offset; _g_syn_i += _local_size)
  579.                 {
  580.                     /* Load syn globals */
  581.                     float stdp_weight = _stdp_weight_buf[_g_syn_i];
  582.                     float stdp_r1 = _stdp_r1_buf[_g_syn_i];
  583.                     float stdp_r2 = _stdp_r2_buf[_g_syn_i];
  584.                     float stdp_last_r_time = _stdp_last_r_time_buf[_g_syn_i];
  585.                     /* Thresh source */
  586.                     stdp_r1 = stdp_r1 * exp(-(t - stdp_last_r_time) / stdp_r1_tau);
  587.                     stdp_r2 = stdp_r2 * exp(-(t - stdp_last_r_time) / stdp_r2_tau);
  588.                     stdp_last_r_time = t;
  589.  
  590.                     if(stdp_active_local > 0)
  591.                         stdp_weight += stdp_Ap * (stdp_r1 - stdp_r2) * (stdp_o1_local - 1.0 - (stdp_o2_local - 1.0)); // Take the o'stdp_s_local from before the update
  592.  
  593.                     /* Save syn globals */
  594.                     _stdp_weight_buf[_g_syn_i] = stdp_weight;
  595.                     _stdp_r1_buf[_g_syn_i] = stdp_r1;
  596.                     _stdp_r2_buf[_g_syn_i] = stdp_r2;
  597.                     _stdp_last_r_time_buf[_g_syn_i] = stdp_last_r_time;
  598.                 }
  599.             }
  600.         }
  601.     }
  602.  
  603.     if(_dt_residual > 0.1f)
  604.         _dt = _dt_residual;
  605.  
  606.     if(_dt > timestep)
  607.         _dt = timestep;
  608.  
  609.     _dt_buf[i] = _dt;
  610.     _rwvalues_0_buf[i] = (float)(V);
  611.     _rwvalues_1_buf[i] = (float)(u);
  612.     _rwvalues_2_buf[i] = (float)(rand_syn_r);
  613.     _rwvalues_3_buf[i] = (float)(static_excitatory_s1);
  614.     _rwvalues_4_buf[i] = (float)(static_excitatory_s2);
  615.     _rwvalues_5_buf[i] = (float)(static_inhibitory_s);
  616.     _rwvalues_6_buf[i] = (float)(stdp_s);
  617.     _rwvalues_7_buf[i] = (float)(stdp_o1);
  618.     _rwvalues_8_buf[i] = (float)(stdp_o2);
  619.     _rwvalues_9_buf[i] = (float)(stdp_last_o_time);
  620.     _rwvalues_10_buf[i] = (float)(stdp_active);
  621.     _rand_state_buf[i] = _rand_state;
  622.     barrier(CLK_LOCAL_MEM_FENCE);
  623.  
  624.     if(_local_id == 0)
  625.     {
  626.         _record_idx[_group_id] = _local_record_idx;
  627.     }
  628. }
Add Comment
Please, Sign In to add comment