Advertisement
toweber

nengo-learncommunication_channel3

Sep 19th, 2022
1,160
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.81 KB | None | 0 0
  1. # https://www.nengo.ai/nengo/examples/learning/learn-communication-channel.html
  2. import os
  3. import ipdb # ipdb.set_trace()
  4.  
  5. import matplotlib.pyplot as plt
  6. import numpy as np
  7.  
  8. import nengo
  9. from nengo.processes import WhiteSignal
  10. from nengo.solvers import LstsqL2
  11.  
  12. model = nengo.Network()
  13. with model:
  14.     inp = nengo.Node(WhiteSignal(60, high=5), size_out=1)
  15.     #inp = nengo.Node(output=1.0, size_out=1)
  16.     pre = nengo.Ensemble(60, dimensions=1)
  17.     nengo.Connection(inp, pre)
  18.     post = nengo.Ensemble(60, dimensions=1)
  19.     conn = nengo.Connection(pre, post, function=lambda x: np.random.random(1))
  20.  
  21.     input_constant = nengo.Node(output=-1, size_out=1)
  22.     input_constant_Ensemble = nengo.Ensemble(100, dimensions=1)
  23.     nengo.Connection(input_constant, input_constant_Ensemble)
  24.  
  25.     inp_p = nengo.Probe(inp)
  26.     pre_p = nengo.Probe(pre, synapse=0.01)
  27.     post_p = nengo.Probe(post, synapse=0.01)
  28.  
  29.     # *********************
  30.     # calculating error
  31.     # *********************
  32.     error = nengo.Ensemble(60, dimensions=1)
  33.     error_p = nengo.Probe(error, synapse=0.03)
  34.  
  35.     # Error = actual - target = post - pre - 1
  36.     nengo.Connection(post, error)
  37.     nengo.Connection(pre, error, transform=-1)
  38.     nengo.Connection(input_constant, error, transform=-1)
  39.  
  40.     # Add the learning rule to the connection
  41.     conn.learning_rule_type = nengo.PES()
  42.  
  43.     # Connect the error into the learning rule
  44.     nengo.Connection(error, conn.learning_rule)
  45.  
  46. with nengo.Simulator(model) as sim:
  47.     sim.run(10.0)
  48.  
  49. plt.figure(figsize=(12, 8))
  50. plt.subplot(2, 1, 1)
  51. plt.plot(sim.trange(), sim.data[inp_p].T[0], c="k", label="Input")
  52. plt.plot(sim.trange(), sim.data[pre_p].T[0], c="b", label="Pre")
  53. plt.plot(sim.trange(), sim.data[post_p].T[0], c="r", label="Post")
  54. plt.ylabel("Dimension 1")
  55. plt.legend()
  56. plt.show()
  57.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement