Advertisement
here2share

# hopfield_solution.py

Mar 30th, 2023
586
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.66 KB | None | 0 0
  1. # hopfield_solution.py
  2.  
  3. import pdb
  4. import math
  5. import random
  6. import os, sys
  7. import ast
  8.  
  9. def print_each_at_x(vars_list, nnn=22):
  10.     sss = ''
  11.     for var in vars_list:
  12.         n = nnn
  13.         if isinstance(var, list):
  14.             var = ' '.join([str(v) for v in var])
  15.         else:
  16.             var = str(var)
  17.         if len(var) > n - 4:
  18.             n *= 2
  19.         var = (var + ' ' * n)[:n-1] + ' '
  20.         sss += var
  21.     print(sss.rstrip())
  22.  
  23. # Hyperparameters
  24. LEARNING_RATE = 0.0001
  25. EPOCHS = 200
  26. THRESHOLD = 0.8
  27.  
  28. # Direct updates towards the correct prediction
  29. def direct_update(weights, node, pi_digit, ppp):
  30.     delta = (1.0 - weights[node, pi_digit] * (weights[node, ppp] - (pi_digit == ppp))) * LEARNING_RATE
  31.     weights[node, ppp] += delta
  32.     return weights
  33.  
  34. # Update the weights using the heaviest weight swapping
  35. def update_weight(weights, nodes, pi_digit, data, ppp, recall_pattern):
  36.     recall_pattern.append(ppp)
  37.     if ppp != pi_digit and len(recall_pattern) >= 2:
  38.         heaviest_node = None
  39.         max_weight = -1
  40.         for node in nodes:
  41.             if node[0] in recall_pattern and node[1] not in recall_pattern:
  42.                 if weights[node, pi_digit] > max_weight:
  43.                     max_weight = weights[node, pi_digit]
  44.                     heaviest_node = node
  45.         if heaviest_node != None:
  46.             weights = direct_update(weights, heaviest_node, pi_digit, ppp)
  47.         recall_pattern.pop(0)
  48.     return weights, recall_pattern
  49.  
  50. # Initialize dictionary for defining weights and targets
  51. def initialize_dict(digits, nodes):
  52.     weights = {}
  53.     target = {}
  54.     for digit in digits:
  55.         for node in nodes:
  56.             weights[node, digit] = 1
  57.             target[node, digit] = 0
  58.     return weights, target
  59.  
  60. def guess(data, weights, nodes, target):
  61.     ddd = {str(i): 0 for i in range(10)}
  62.     percentages = {}
  63.     for node in nodes:
  64.         ddd[str(target[node, pi_digit])] += weights[node, pi_digit]
  65.         for i in range(10):
  66.             percentages[str(i)] = round(ddd[str(i)] * (100.0/sum(ddd.values())), 6)
  67.     ppp = max(percentages, key=percentages.get)
  68.     if percentages[ppp] < THRESHOLD*100:
  69.         ppp = 'X'
  70.     prediction = ppp == pi_digit
  71.     return prediction, ppp
  72.  
  73. def train(weights, target, nodes, recall_pattern):
  74.     right = 0
  75.     wrong = 0
  76.     for epoch in range(EPOCHS):
  77.         random.shuffle(data)
  78.         for digit, pixels in data:
  79.             pi_digit = str(digit)
  80.             recall_pattern = []
  81.             for i, pixel in enumerate(pixels):
  82.                 ppp = '1' if pixel else '-1'
  83.                 weights, recall_pattern = update_weight(weights, nodes, pi_digit, data, ppp, recall_pattern)
  84.  
  85.             prediction, ppp = guess(data, weights, nodes, target)
  86.             if prediction:
  87.                 right += 1
  88.             else:
  89.                 wrong += 1
  90.  
  91.         if epoch % 10 == 0:
  92.             print_each_at_x(["epoch", epoch, "accuracy", round(100.0*right/(right+wrong),2), "X=", ppp, "perc", percentages])
  93.             sys.stdout.flush()
  94.  
  95. digits = [str(i) for i in range(10)]
  96. nodes = [(i, j) for i in range(28*28) for j in digits]
  97.  
  98. weights, target = initialize_dict(digits, nodes)
  99.  
  100. # Train the network
  101. train(weights, target, nodes, [])
  102.  
  103. # Save the trained weights to a file
  104. with open('weights.txt', 'w') as f:
  105.     f.write(str(weights))
  106.  
  107. # Load the trained weights from a file
  108. with open('weights.txt', 'r') as f:
  109.     weights = ast.literal_eval(f.read())
  110.  
  111. # Example usage:
  112. # test_data = [(8, [0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1])]
  113. # prediction, ppp = guess(test_data, weights, nodes, target)
  114. # print(prediction, ppp)
  115.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement