Advertisement
Guest User

Temporal Pooler

a guest
Feb 8th, 2023
73
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.63 KB | None | 0 0
  1. import cupy as cp
  2. import numpy as np
  3. import copy
  4. import math
  5.  
  6. # Code adapted from
  7. # https://github.com/numenta/htmresearch/blob/master/htmresearch/algorithms/union_temporal_pooler.py
  8. class TemporalPooler:
  9.     def __init__(self, input_size, cells,
  10.                sparsity=0.05,
  11.                activeOverlapWeight=1.0,
  12.                predictedActiveOverlapWeight=10.0,
  13.                maxUnionActivity=0.20,
  14.                decayTimeConst=20.0,
  15.                synPermPredActiveInc=0.3,
  16.                synPermPreviousPredActiveInc=0.3,
  17.                historyLength=3,
  18.                minHistory=0):
  19.        
  20.         self.input_size = input_size
  21.         self.cells = cells
  22.         self.sparsity = sparsity
  23.         self.duty_cycle = cp.zeros(self.cells, dtype=np.float32)
  24.         self.boosting_intensity = 0.3
  25.         self.duty_cycle_inertia = 0.99
  26.  
  27.         self.permanence = cp.random.randn(self.cells, self.input_size)
  28.         self.active_overlap_weight = activeOverlapWeight
  29.         self.predicted_active_overlap_weight = predictedActiveOverlapWeight
  30.         self.max_union_activity = maxUnionActivity
  31.        
  32.         self.permanence_threshold = 0.0
  33.        
  34.         self.syn_perm_active_inc = 0.2
  35.         self.syn_perm_inactive_dec = 0.3
  36.         self.syn_perm_pred_active_inc = synPermPredActiveInc
  37.         self.syn_perm_previous_pred_active_inc = synPermPreviousPredActiveInc
  38.         self.history_length = historyLength
  39.         self.min_history = minHistory
  40.        
  41.         self.max_union_cells = int(cells * self.max_union_activity)
  42.         self.pooling_activation = cp.zeros(cells, dtype=np.float32)
  43.  
  44.         self.pooling_timer = cp.ones(cells, dtype=np.float32) * 1000
  45.         self.pooling_activation_init_level = cp.zeros(cells, dtype=np.float32)
  46.         self.pooling_activation_tie_breaker = cp.random.randn(cells) * 0.000001
  47.         self.union_SDR = cp.array([], dtype=np.int32)
  48.         self.active_cells = cp.array([], dtype=np.int32)
  49.        
  50.         self.pred_active_input = cp.zeros(input_size, dtype=np.float32)
  51.         self.pre_predicted_active_input = cp.zeros((input_size, self.history_length), dtype=np.int32)
  52.         self.prev_active_cells = cp.zeros(cells, dtype=np.int32)
  53.  
  54.     def run(self, active_input, predicted_active_input, train=True):
  55.         weight = self.permanence > self.permanence_threshold
  56.        
  57.         # Compute proximal dendrite overlaps with active and active-predicted inputs
  58.         overlaps_active = cp.sum(active_input & weight, axis=1)
  59.         overlaps_predicted_active = cp.sum(predicted_active_input & weight, axis=1)
  60.        
  61.         total_overlap = (overlaps_active * self.active_overlap_weight +
  62.                         overlaps_predicted_active *
  63.                         self.predicted_active_overlap_weight).astype(np.float32)
  64.        
  65.         #perform global inhibition
  66.         boosting = cp.exp(self.boosting_intensity * -self.duty_cycle / self.sparsity)
  67.         total_overlap *= boosting
  68.        
  69.         k_winners = int(self.sparsity * self.cells)
  70.         self.active_cells = total_overlap.argsort()[-k_winners:]
  71.        
  72.         # Decrement pooling activation of all cells
  73.         self.decay_pooling_activation()
  74.    
  75.         # Update the poolingActivation of current active Union Temporal Pooler cells
  76.         self.add_to_pooling_activation(self.active_cells, overlaps_predicted_active)
  77.    
  78.         # update union SDR
  79.         self.get_most_active_cells()
  80.        
  81.         if train:
  82.             # Adjust boosting factor
  83.             self.duty_cycle *= self.duty_cycle_inertia
  84.             self.duty_cycle[self.active_cells] += 1.0 - self.duty_cycle_inertia
  85.            
  86.  
  87.             # adapt permanence of connections from predicted active inputs to newly active cell
  88.             # This step is the spatial pooler learning rule, applied only to the predictedActiveInput
  89.             # Todo: should we also include unpredicted active input in this step?
  90.             #self.adapt_synapses(predicted_active_input, active_cells, self.syn_perm_active_inc, self.syn_perm_inactive_dec)
  91.             self.permanence[self.active_cells] += predicted_active_input * (self.syn_perm_active_inc + self.syn_perm_inactive_dec) - self.syn_perm_inactive_dec
  92.              
  93.             # Increase permanence of connections from predicted active inputs to cells in the union SDR
  94.             # This is Hebbian learning applied to the current time step
  95.             #self.adapt_synapses(predicted_active_input, self.union_sdr, self.syn_perm_pred_active_inc, 0.0)
  96.             self.permanence[self.union_SDR] += predicted_active_input * self.syn_perm_pred_active_inc
  97.  
  98.             # adapt permenence of connections from previously predicted inputs to newly active cells
  99.             # This is a reinforcement learning rule that considers previous input to the current cell
  100.             for i in range(self.history_length):
  101.                 self.permanence[self.active_cells] += self.pre_predicted_active_input[:,i] * self.syn_perm_previous_pred_active_inc
  102.        
  103.         #Save previous inputs
  104.         self.pre_active_input = copy.copy(active_input)
  105.         self.pre_predicted_active_input = cp.roll(self.pre_predicted_active_input, 1, 1)
  106.         if self.history_length > 0:
  107.             self.pre_predicted_active_input[:,0] = predicted_active_input
  108.        
  109.         cp.cuda.Stream.null.synchronize()
  110.         return self.union_SDR
  111.    
  112.     def decay_pooling_activation(self):
  113.         """
  114.        Decrements pooling activation of all cells
  115.        """
  116.         #exponential decay
  117.         self.pooling_activation = cp.exp(-0.1 * self.pooling_timer) *  self.pooling_activation_init_level
  118.        
  119.         #no decay
  120.         #self.pooling_activation = self.pooling_activation_init_level
  121.        
  122.         return self.pooling_activation
  123.    
  124.     def add_to_pooling_activation(self, active_cells, overlaps):
  125.         """
  126.        Adds overlaps from specified active cells to cells' pooling
  127.        activation.
  128.        @param activeCells: Indices of those cells winning the inhibition step
  129.        @param overlaps: A current set of overlap values for each cell
  130.        @return current pooling activation
  131.        """
  132.         # Sigmoid activation
  133.         """baseLinePersistence = 10
  134.        extraPersistence = 10
  135.        thresh = 5
  136.        self.pooling_activation[active_cells] = baseLinePersistence + extraPersistence/(1 + cp.exp(-(overlaps[active_cells] - thresh)))  """
  137.        
  138.         self.pooling_activation[active_cells] += overlaps[active_cells]
  139.    
  140.        
  141.         self.pooling_timer[self.pooling_timer >= 0] += 1
  142.        
  143.         self.pooling_timer[active_cells] = 0
  144.         self.pooling_activation_init_level[active_cells] = self.pooling_activation[active_cells]
  145.        
  146.         return self.pooling_activation
  147.    
  148.     def get_most_active_cells(self):
  149.         """
  150.        Gets the most active cells in the Union SDR having at least non-zero
  151.        activation in sorted order.
  152.        @return: a list of cell indices
  153.        """
  154.         pooling_activation = self.pooling_activation
  155.         non_zero_cells = cp.nonzero(pooling_activation)[0]
  156.          
  157.         # include a tie-breaker before sorting
  158.         pooling_activation_subset = pooling_activation[non_zero_cells] + \
  159.                                   self.pooling_activation_tie_breaker[non_zero_cells]
  160.  
  161.         sorted_cells = non_zero_cells[pooling_activation_subset[non_zero_cells].argsort()[::-1]]
  162.         top_cells = sorted_cells[:self.max_union_cells]
  163.  
  164.         if max(self.pooling_timer) > self.min_history:
  165.             self.union_SDR = cp.sort(top_cells).astype(np.int32)
  166.         else:
  167.             self.union_SDR = []
  168.            
  169.         return self.union_SDR
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement