Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Reference: https://www.nengo.ai/nengo/examples/learning/learn-communication-channel.html
import os

import ipdb  # ipdb.set_trace()
import matplotlib.pyplot as plt
import nengo
import numpy as np
from nengo.processes import WhiteSignal
from nengo.solvers import LstsqL2
# Build the network: a 2-D white-noise input drives `pre`, and the
# pre -> post connection is trained online with the PES rule so that `post`
# ends up reproducing the value represented by `pre`.
model = nengo.Network()
with model:
    # All Nengo objects must be created inside the network context so they
    # are registered on `model`.
    inp = nengo.Node(WhiteSignal(60, high=5), size_out=2)
    pre = nengo.Ensemble(60, dimensions=2)
    nengo.Connection(inp, pre)
    post = nengo.Ensemble(60, dimensions=2)

    # The connection function ignores its input and returns random values,
    # so the connection starts out computing an arbitrary mapping that the
    # learning rule then has to correct.
    conn = nengo.Connection(pre, post, function=lambda x: np.random.random(2))

    inp_p = nengo.Probe(inp)
    pre_p = nengo.Probe(pre, synapse=0.01)
    post_p = nengo.Probe(post, synapse=0.01)

    # *********************
    # calculating error
    # *********************
    error = nengo.Ensemble(60, dimensions=2)
    error_p = nengo.Probe(error, synapse=0.03)

    # Error = actual - target = post - pre
    nengo.Connection(post, error)
    nengo.Connection(pre, error, transform=-1)

    # Add the learning rule to the connection
    conn.learning_rule_type = nengo.PES()

    # Connect the error into the learning rule
    nengo.Connection(error, conn.learning_rule)
# Build a simulator for `model` and run it for 10 seconds of simulated time.
# `sim` is read again after the block exits, for its time axis and probe data.
with nengo.Simulator(model) as sim:
    sim.run(10.0)
# Plot the input, `pre` and `post` probe traces against simulation time,
# one stacked subplot per represented dimension.
t = sim.trange()  # time axis shared by every trace
plt.figure(figsize=(12, 8))
for dim in range(2):
    plt.subplot(2, 1, dim + 1)
    # Probe data is transposed so that index `dim` selects one dimension.
    plt.plot(t, sim.data[inp_p].T[dim], c="k", label="Input")
    plt.plot(t, sim.data[pre_p].T[dim], c="b", label="Pre")
    plt.plot(t, sim.data[post_p].T[dim], c="r", label="Post")
    plt.ylabel("Dimension {}".format(dim + 1))
    plt.legend(loc="best")
plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement