Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# based on https://www.nengo.ai/nengo/examples/learning/learn-square.html
#
# Build a Nengo network that learns to compute x**2 online with the PES
# rule: ensemble A represents the input, A_squared learns the squared
# value, and an error ensemble drives the decoder update.
import matplotlib.pyplot as plt
import numpy as np

import nengo
from nengo.dists import Uniform
from nengo.utils.matplotlib import rasterplot

model = nengo.Network()
with model:
    # Ensembles representing the input, the input squared (learned),
    # and the error signal.
    A = nengo.Ensemble(100, dimensions=1)
    A_squared = nengo.Ensemble(100, dimensions=1)
    error = nengo.Ensemble(100, dimensions=1)

    # ----------------
    # PROVIDE INPUT
    # ----------------
    # Piecewise-constant step signal: steps every 5/6 s in increments of
    # 1/3, wrapped by "% 2" and shifted into [-1, 1).
    input_node = nengo.Node(output=lambda t: int(6 * t / 5) / 3.0 % 2 - 1)

    # ----------------
    # CONNECT ELEMENTS
    # ----------------
    # Feed the input into ensemble A.
    nengo.Connection(input_node, A)
    # Communication channel A -> A_squared; the PES rule below will
    # reshape this connection's decoders so it computes x**2.
    conn = nengo.Connection(A, A_squared)

    # Error = actual - target, assembled from two connections:
    #   error  = A_squared           (the current decoded output)
    #   error -= A**2                (the target, computed ideally)
    nengo.Connection(A_squared, error)
    nengo.Connection(A, error, function=lambda x: x**2, transform=-1)

    # ----------------
    # LEARNING
    # ----------------
    # Apply the PES learning rule to conn and feed it the error signal.
    conn.learning_rule_type = nengo.PES(learning_rate=3e-4)
    nengo.Connection(error, conn.learning_rule)

    # Shut off learning after t >= 15 s by strongly inhibiting the
    # error population (no error activity -> no decoder updates).
    stop_learning = nengo.Node(output=lambda t: t >= 15)
    nengo.Connection(
        stop_learning, error.neurons, transform=-20 * np.ones((error.n_neurons, 1))
    )

    # ----------------
    # PROBES
    # ----------------
    input_node_probe = nengo.Probe(input_node)
    A_probe = nengo.Probe(A, synapse=0.01)
    A_squared_probe = nengo.Probe(A_squared, synapse=0.01)
    error_probe = nengo.Probe(error, synapse=0.01)
    learn_probe = nengo.Probe(stop_learning, synapse=None)
# ----------------
# Run
# ----------------
# Build and run the simulator; learning is active for the first 15 s,
# then inhibited (see stop_learning) for the final 5 s.
with nengo.Simulator(model) as sim:
    sim.run(20)  # seconds
# ----------------
# Plot the results
# ----------------
t = sim.trange()

plt.figure(figsize=(9, 9))

# 1) Input signal and the learning gate.
plt.subplot(3, 1, 1)
plt.plot(t, sim.data[input_node_probe], label="Input", color="k", linewidth=2.0)
plt.plot(
    t,
    sim.data[learn_probe],
    label="Stop learning?",
    color="r",
    linewidth=2.0,
)
plt.legend(loc="lower right")
plt.ylim(-1.2, 1.2)

# 2) Ideal squared input vs. the learned decoding of A_squared.
plt.subplot(3, 1, 2)
plt.plot(t, sim.data[input_node_probe] ** 2, label="Squared Input", linewidth=2.0)
plt.plot(t, sim.data[A_squared_probe], label="Decoded Ensemble $A^2$")
plt.legend(loc="lower right")
plt.ylim(-1.2, 1.2)

# 3) Residual error between the decoded value and the ideal square.
plt.subplot(3, 1, 3)
plt.plot(
    t,
    sim.data[A_squared_probe] - sim.data[input_node_probe] ** 2,
    label="Error",
)
plt.legend(loc="lower right")
plt.tight_layout()
plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement