Advertisement
here2share

# swap_closest_weight.py

Mar 30th, 2023
486
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.32 KB | None | 0 0
  1. # swap_closest_weight.py
  2.  
  3. def swap_closest_weight(weights, incorrect_pred):
  4.     """
  5.    Swap an incorrect weight with the weight closest to it which is above a certain threshold
  6.    """
  7.     threshold = 5 # Set a threshold for weight updates
  8.     delta = float('inf')
  9.     current_weight = weights[incorrect_pred]
  10.     for weight in weights:
  11.         if weights[weight] > threshold and weight != incorrect_pred:
  12.             new_delta = abs(weights[weight] - current_weight)
  13.             if new_delta < delta:
  14.                 delta = new_delta
  15.                 closest_weight = weight
  16.     weights[incorrect_pred] -= 1
  17.     weights[closest_weight] += 1
  18.  
  19. def guess(data, weights):
  20.     """
  21.    Returns the predicted label and whether it is correct, given the input data
  22.    """
  23.     percentages = {}
  24.     for i in range(10):
  25.         total_weight = sum([weights[node, str(i)] for node in nodes if str(i) in node])
  26.         percentages[str(i)] = total_weight
  27.     try:
  28.         total = 100 / sum(percentages.values())
  29.     except:
  30.         return False, 'X'
  31.     percentages = {k: round(v*total, 6) for k, v in percentages.items()}
  32.     sorted_ppp = sorted(percentages.items(), key=lambda x: x[1], reverse=True)
  33.     ppp = sorted_ppp[0][0]
  34.     correct = data[-1] == ppp
  35.     if not correct:
  36.         swap_closest_weight(weights, weights[data[-1]])
  37.     return correct, ppp
  38.  
  39. def pattern_recognition(pi):
  40.     """
  41.    Perform pattern recognition on a string of digits
  42.    """
  43.     weights = {(node, i): 1 for node in nodes for i in range(10)}
  44.     right_answers = [(pi[i], str(i).zfill(4)) for i in range(len(pi))]
  45.     random.shuffle(right_answers)
  46.     correct_count = 0
  47.     incorrect_count = 0
  48.     threshold = 0.1
  49.     learning_rate = 0.1
  50.     for i, (digit, data) in enumerate(right_answers):
  51.         prediction_correct, prediction = guess(data, weights)
  52.         if prediction_correct:
  53.             correct_count += 1
  54.             incorrect_count = 0
  55.         else:
  56.             incorrect_count += 1
  57.             if incorrect_count > 1 and random.random() > threshold:
  58.                 swap_closest_weight(weights, weights[digit])
  59.             else:
  60.                 weights[digit] -= learning_rate
  61.                 weights[prediction] += learning_rate
  62.         if correct_count + incorrect_count == len(right_answers):
  63.             break
  64.     return weights
  65.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement