Advertisement
Guest User

Untitled

a guest
Jun 8th, 2018
83
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 8.46 KB | None | 0 0
  1. ''' --- TEMPORAL MEMORY FUNCTIONS --- '''
  2. def cellsFromCol(col):
  3.     global TM_cellsPerCol
  4.     minCol = (col*TM_cellsPerCol)+1
  5.     cells = [i+minCol for i in range(TM_cellsPerCol)]
  6.     return cells
  7.  
  8. def colFromCell(cell):
  9.     global TM_cellsPerCol
  10.     col = int((float(cell) / TM_cellsPerCol) - 0.01)
  11.     return col
  12.  
  13. def segmentsFromCell(cell,cells_segments):
  14.     segs = cells_segments[cell]
  15.     return segs
  16.  
  17. def segmentsFromCol(col,cells_segments):
  18.     cells_Segs = {}
  19.     cells = cellsFromCol(col)
  20.     for c in cells:
  21.         cells_Segs[c] = segmentsFromCell(c,cells_segments)
  22.     return cells_Segs
  23.  
  24. def activatePredColCells(col,prevActiveCells,prevWinnerCells,colActiveCells,colMatchingCells,cells_segments):
  25.     global TM_segActiveThresh,TM_permINC,TM_permDEC,TM_segNewSynCount,TM_learn
  26.     activeCs = list()
  27.     winnerCs = list()    
  28.     for cell in colMatchingCells:
  29.         activeCs.append(cell)
  30.         winnerCs.append(cell)
  31.         if TM_learn:
  32.             segList = cells_segments[cell]
  33.             for seg in segList:
  34.                 segIndex = segList.index(seg)
  35.                 if len(activeSynapsesOnSegment(seg,prevActiveCells)) > TM_segActiveThresh:
  36.                     for preSynCell,perm in seg.items():
  37.                         if preSynCell in prevActiveCells:
  38.                             if (perm + TM_permINC) > 1.0:
  39.                                 cells_segments[cell][segIndex][preSynCell] = 1.0
  40.                             else:
  41.                                 cells_segments[cell][segIndex][preSynCell] += TM_permINC
  42.                         else:
  43.                             if (perm - TM_permDEC) < 0:
  44.                                 cells_segments[cell][segIndex][preSynCell] = 0
  45.                             else:
  46.                                 cells_segments[cell][segIndex][preSynCell] -= TM_permDEC              
  47.                     key = str(cell) + '_' + str(segIndex)
  48.                     newSynCount = TM_segNewSynCount - cellsSegInds_numActivePotentialSyns[key]
  49.                     growSynapses(seg,cell,newSynCount,prevWinnerCells)
  50.     return activeCs,winnerCs,cells_segments
  51.  
  52. def activeSynapsesOnSegment(segment,prevActiveCells):
  53.     preSynCells = list(segment.keys())
  54.     actives = [c for c in preSynCells if c in prevActiveCells]
  55.     return actives
  56.  
  57. def growSynapses(segment,cell,newSynapseCount,prevWinnerCells):
  58.     global TM_segNewSynCount,TM_permINIT
  59.     candidates = copy.copy(prevWinnerCells)
  60.     while len(candidates) > 0 and newSynapseCount > 0:
  61.         presynapticCell = rd.choice(candidates)
  62.         candidates.remove(presynapticCell)
  63.         alreadyConnected = False
  64.         for preSynCell,perm in segment.items():
  65.             if preSynCell == presynapticCell:
  66.                 alreadyConnected = True
  67.         if not alreadyConnected:
  68.             createNewSynapse(segment,cell,presynapticCell,TM_permINIT)
  69.             newSynapseCount -= 1
  70.  
  71. def createNewSynapse(segment,cell,presynapticCell,TM_permINIT):
  72.     global cells_segments
  73.     segDicts = cells_segments[cell]
  74.     seg_index = segDicts.index(segment)
  75.     cells_segments[cell][seg_index][presynapticCell] = TM_permINIT
  76.  
  77. def burstCol(col,prevActiveCells,prevWinnerCells,activeSegCells,matchingSegCells,cells_segments,cellsSegInds_numActivePotentialSyns):
  78.     global TM_segNewSynCount
  79.     activeCs = list()    
  80.     winnerCs = list()
  81.     for cell in cellsFromCol(col):
  82.         activeCs.append(cell)    
  83.     if len(matchingSegCells) > 0:
  84.         learningSegment,winnerCell = bestMatchingSegmentAndCell(col,matchingSegCells,prevWinnerCells,cells_segments,cellsSegInds_numActivePotentialSyns)
  85.     else:
  86.         winnerCell = leastUsedCell(col,cells_segments)
  87.         if TM_learn and len(prevWinnerCells) > 0:
  88.             learningSegment,cells_segments = growNewSegment(winnerCell,prevWinnerCells,prevActiveCells,cells_segments)
  89.     winnerCs.append(winnerCell)
  90.     if TM_learn and len(prevWinnerCells) > 0:
  91.         segIndex = cells_segments[winnerCell].index(learningSegment)
  92.         for preSynCell,perm in learningSegment.items():
  93.             if preSynCell in prevActiveCells:
  94.                 if (perm + TM_permINC) > 1.0:
  95.                     cells_segments[winnerCell][segIndex][preSynCell] = 1.0
  96.                 else:
  97.                     cells_segments[winnerCell][segIndex][preSynCell] += TM_permINC
  98.             else:
  99.                 if (perm - TM_permDEC) < 0:
  100.                     cells_segments[winnerCell][segIndex][preSynCell] = 0
  101.                 else:
  102.                     cells_segments[winnerCell][segIndex][preSynCell] -= TM_permDEC                  
  103.         key = str(cell) + '_' + str(segIndex)
  104.         numActiveSynsOnSeg = 0
  105.         if key in cellsSegInds_numActivePotentialSyns:
  106.             numActiveSynsOnSeg = cellsSegInds_numActivePotentialSyns[key]
  107.         newSynapseCount = TM_segNewSynCount - numActiveSynsOnSeg
  108.         growSynapses(learningSegment,winnerCell,newSynapseCount,prevWinnerCells)
  109.     return activeCs,winnerCs,cells_segments
  110.  
  111. def getActiveMatchingCellsFromCol(col,prevActiveCells,cells_segments):
  112.     global TM_permConnectThresh
  113.     cells = cellsFromCol(col)
  114.     activeSegCells, matchingSegCells = [[],[]]
  115.     for c in cells:
  116.         if len(cells_segments[c]) > 0:
  117.             for seg in cells_segments[c]:
  118.                 connectedSyns = []        
  119.                 for preSynCell, perm in seg.items():        
  120.                     if perm >= TM_permConnectThresh:            
  121.                         connectedSyns.append(preSynCell)            
  122.                 connected_prevActive_syns = [cell for cell in connectedSyns if cell in prevActiveCells]
  123.                 if len(connected_prevActive_syns) >= TM_segActiveThresh:
  124.                     activeSegCells.append(c)
  125.                 if len(connected_prevActive_syns) >= TM_segMatchingThres:
  126.                     matchingSegCells.append(c)
  127.     return activeSegCells,matchingSegCells
  128.  
  129. def bestMatchingSegmentAndCell(col,matchingCells,prevActiveCells,cells_segments,cellsSegInds_numActivePotentialSyns):
  130.     bestMatchSeg = None
  131.     bestMatchCell = None
  132.     bestScore = -1
  133.     for cell in matchingCells:
  134.         segDicts = cells_segments[cell]            
  135.         for seg in segDicts:
  136.             segIndex = segDicts.index(seg)
  137.             key = str(cell) + '_' + str(segIndex)
  138.             numActiveSyns = 0
  139.             if key in cellsSegInds_numActivePotentialSyns:
  140.                 numActiveSyns = cellsSegInds_numActivePotentialSyns[key]
  141.             if numActiveSyns > bestScore:
  142.                 bestMatchSeg = seg
  143.                 bestMatchCell = cell
  144.                 bestScore = numActiveSyns
  145.     return bestMatchSeg,bestMatchCell
  146.    
  147. def leastUsedCell(col,cells_segments):
  148.     global TM_cellMaxSegs
  149.     cells = cellsFromCol(col)
  150.     leastSegCells = list()
  151.     minSegs = TM_cellMaxSegs
  152.     for c in cells:                
  153.         seg_count = len(cells_segments[c])
  154.         minSegs = min(minSegs,seg_count)
  155.     for c in cells:
  156.         if len(cells_segments[c]) == minSegs:
  157.             leastSegCells.append(c)
  158.     return rd.choice(leastSegCells)
  159.  
  160. def growNewSegment(cell,prevCells,prevActives,cells_segments):
  161.     global TM_segNewSynCount,TM_permINIT
  162.     if len(prevCells) < len(prevActives):
  163.         prevCells = prevActives
  164.     preSynCells = rd.sample(prevCells,TM_segNewSynCount)
  165.     seg = {}
  166.     for c in preSynCells:
  167.         seg[c] = TM_permINIT
  168.     cells_segments[cell].append(seg)
  169.     return seg,cells_segments    
  170.  
  171. def decSeg(seg,cell,segIndex,prevActiveCells,cells_segments):
  172.     for preSynCell,perm in seg.items():
  173.         if preSynCell in prevActiveCells:
  174.             currentPerm = cells_segments[cell][segIndex][preSynCell]
  175.             if (currentPerm - TM_permDEC) < 0:    
  176.                 cells_segments[cell][segIndex][preSynCell] = 0
  177.             else:
  178.                 cells_segments[cell][segIndex][preSynCell] -= TM_permDEC
  179.     return cells_segments
  180.  
  181. def decColSegments(col,prevActiveCells,matchingSegCells,cells_segments):
  182.     global TM_segMatchingThres,TM_learn
  183.     if TM_learn:
  184.         for cell in matchingSegCells:
  185.             segList = cells_segments[cell]
  186.             for seg in segList:
  187.                 segIndex = segList.index(seg)
  188.                 segPreSynCells = list(seg.keys())
  189.                 numActivePotentialSyns = len([c for c in segPreSynCells if c in prevActiveCells])
  190.                 if numActivePotentialSyns > TM_segMatchingThres:
  191.                     cells_segments = decSeg(seg,cell,segIndex,prevActiveCells,cells_segments)
  192.     return cells_segments
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement