Advertisement
toweber

nengo-learning_test

Sep 19th, 2022
760
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.62 KB | None | 0 0
  1. # based on https://www.nengo.ai/nengo/examples/learning/learn-square.html
  2.  
  3. import matplotlib.pyplot as plt
  4. import numpy as np
  5. import nengo
  6. from nengo.dists import Uniform
  7. from nengo.utils.matplotlib import rasterplot
  8.  
  9.  
  10. model = nengo.Network()
  11.  
  12. with model:
  13.     # Create the ensemble to represent the input, the input squared (learned),
  14.     # and the error
  15.     A = nengo.Ensemble(100, dimensions=1)
  16.     A_squared = nengo.Ensemble(100, dimensions=1)
  17.     error = nengo.Ensemble(100, dimensions=1)
  18.  
  19.     #****************        
  20.     # PROVIDE INPUT
  21.     #****************
  22.     #input_signal = nengo.Node(lambda t: np.cos(8 * t))
  23.     #input_a = nengo.Node(output=0.5)
  24.     #input_b = nengo.Node(output=0.3)
  25.     input_node = nengo.Node(output=lambda t: int(6 * t / 5) / 3.0 % 2 - 1)
  26.  
  27.     #****************
  28.     # CONNECT ELEMENTS
  29.     #****************
  30.     #Connect the input signal to the neuron
  31.     # The indices in neurons define which dimension the input will project to
  32.  
  33.     #nengo.Connection(input_signal, A, synapse=0.01)
  34.     #nengo.Connection(sin, neurons[0])
  35.     #nengo.Connection(cos, neurons[1])
  36.  
  37.     # Connect the input node to ensemble A
  38.     nengo.Connection(input_node, A)
  39.  
  40.     # Connect A and A_squared with a communication channel
  41.     conn = nengo.Connection(A, A_squared)
  42.  
  43.     # learning related connections
  44.     # the part below
  45.     # Compute the error signal (error = actual - target)
  46.     # PS: it seems this lines makes: error = A_squared
  47.     nengo.Connection(A_squared, error)
  48.  
  49.     # Subtract the target (this would normally come from some external system)
  50.     # PS: it seems this lines makes: error += (-A**2)
  51.     nengo.Connection(A, error, function=lambda x: x**2, transform=-1)
  52.  
  53.     #****************
  54.     # LEARNING
  55.     #****************
  56.  
  57.     # Apply the PES learning rule to conn
  58.     conn.learning_rule_type = nengo.PES(learning_rate=3e-4)
  59.  
  60.     # Provide an error signal to the learning rule
  61.     nengo.Connection(error, conn.learning_rule)
  62.  
  63.     # Shut off learning by inhibiting the error population
  64.     stop_learning = nengo.Node(output=lambda t: t >= 15)
  65.     nengo.Connection(
  66.         stop_learning, error.neurons, transform=-20 * np.ones((error.n_neurons, 1))
  67.     )
  68.  
  69.  
  70.     #****************
  71.     # PROBES
  72.     #****************
  73.     # The original input
  74.     #input_signal_probe = nengo.Probe(input_signal)
  75.     input_node_probe = nengo.Probe(input_node)
  76.     A_probe = nengo.Probe(A, synapse=0.01)
  77.     A_squared_probe = nengo.Probe(A_squared, synapse=0.01)
  78.     error_probe = nengo.Probe(error, synapse=0.01)
  79.     learn_probe = nengo.Probe(stop_learning, synapse=None)
  80.  
  81. with nengo.Simulator(model) as sim:  # Create the simulator
  82.     #****************
  83.     # Run
  84.     #****************
  85.     sim.run(20)  # seconds
  86.  
  87. #****************
  88. # plot the results
  89. #****************
  90. t =  sim.trange()
  91.  
  92. # Plot the input signal
  93. plt.figure(figsize=(9, 9))
  94. plt.subplot(3, 1, 1)
  95. plt.plot(
  96.     sim.trange(), sim.data[input_node_probe], label="Input", color="k", linewidth=2.0
  97. )
  98. plt.plot(
  99.     sim.trange(),
  100.     sim.data[learn_probe],
  101.     label="Stop learning?",
  102.     color="r",
  103.     linewidth=2.0,
  104. )
  105. plt.legend(loc="lower right")
  106. plt.ylim(-1.2, 1.2)
  107.  
  108. plt.subplot(3, 1, 2)
  109. plt.plot(
  110.     sim.trange(), sim.data[input_node_probe] ** 2, label="Squared Input", linewidth=2.0
  111. )
  112. plt.plot(sim.trange(), sim.data[A_squared_probe], label="Decoded Ensemble $A^2$")
  113. plt.legend(loc="lower right")
  114. plt.ylim(-1.2, 1.2)
  115.  
  116. plt.subplot(3, 1, 3)
  117. plt.plot(
  118.     sim.trange(),
  119.     sim.data[A_squared_probe] - sim.data[input_node_probe] ** 2,
  120.     label="Error",
  121. )
  122. plt.legend(loc="lower right")
  123. plt.tight_layout()
  124. plt.show()
  125.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement