Advertisement
toweber

nengo-learncommunication_channel2

Sep 19th, 2022
855
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.80 KB | None | 0 0
  1. # https://www.nengo.ai/nengo/examples/learning/learn-communication-channel.html
  2. import os
  3. import ipdb # ipdb.set_trace()
  4.  
  5. import matplotlib.pyplot as plt
  6. import numpy as np
  7.  
  8. import nengo
  9. from nengo.processes import WhiteSignal
  10. from nengo.solvers import LstsqL2
  11.  
  12. model = nengo.Network()
  13. with model:
  14.     inp = nengo.Node(WhiteSignal(60, high=5), size_out=2)
  15.     pre = nengo.Ensemble(60, dimensions=2)
  16.     nengo.Connection(inp, pre)
  17.     post = nengo.Ensemble(60, dimensions=2)
  18.     conn = nengo.Connection(pre, post, function=lambda x: np.random.random(2))
  19.     inp_p = nengo.Probe(inp)
  20.     pre_p = nengo.Probe(pre, synapse=0.01)
  21.     post_p = nengo.Probe(post, synapse=0.01)
  22.  
  23.     # *********************
  24.     # calculating error
  25.     # *********************
  26.     error = nengo.Ensemble(60, dimensions=2)
  27.     error_p = nengo.Probe(error, synapse=0.03)
  28.  
  29.     # Error = actual - target = post - pre
  30.     nengo.Connection(post, error)
  31.     nengo.Connection(pre, error, transform=-1)
  32.  
  33.     # Add the learning rule to the connection
  34.     conn.learning_rule_type = nengo.PES()
  35.  
  36.     # Connect the error into the learning rule
  37.     nengo.Connection(error, conn.learning_rule)
  38.  
  39. with nengo.Simulator(model) as sim:
  40.     sim.run(10.0)
  41.  
  42. plt.figure(figsize=(12, 8))
  43. plt.subplot(2, 1, 1)
  44. plt.plot(sim.trange(), sim.data[inp_p].T[0], c="k", label="Input")
  45. plt.plot(sim.trange(), sim.data[pre_p].T[0], c="b", label="Pre")
  46. plt.plot(sim.trange(), sim.data[post_p].T[0], c="r", label="Post")
  47. plt.ylabel("Dimension 1")
  48. plt.legend(loc="best")
  49. plt.subplot(2, 1, 2)
  50. plt.plot(sim.trange(), sim.data[inp_p].T[1], c="k", label="Input")
  51. plt.plot(sim.trange(), sim.data[pre_p].T[1], c="b", label="Pre")
  52. plt.plot(sim.trange(), sim.data[post_p].T[1], c="r", label="Post")
  53. plt.ylabel("Dimension 2")
  54. plt.legend(loc="best")
  55.  
  56. plt.show()
  57.  
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement