Guest User

evolve_nn

a guest
Mar 8th, 2017
106
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.33 KB | None | 0 0
  1. """
  2. Double-pole balancing experiment using a feed-forward neural network.
  3. """
  4.  
  5. from __future__ import print_function
  6.  
  7. import os
  8. import pickle
  9.  
  10. import numpy as np
  11.  
  12. import cart_pole
  13. from neat import nn, parallel, population
  14. from neat.config import Config
  15. from math import radians as rad
  16.  
# Number of independent simulation runs performed for each network;
# evaluate_genome returns the minimum score over these runs.
runs_per_net = 5
# Upper bound on the number of control steps in a single run.
# NOTE(review): this was described as "equivalent to 1 minute of simulation
# time", but dpb_factory sets time_step = 0.01 and evaluate_genome calls
# update_state() twice per step, so the simulated time may be longer - confirm.
num_steps = 60000
  19.  
  20.  
  21. def dpb_factory():
  22.     dpb = cart_pole.PoledCart(2)
  23.     dpb.pole_number = 2
  24.     p_masses = [0.1] * dpb.pole_number
  25.     p_angles = [0.0, rad(1.0)]
  26.     p_h_lens = [0.1, 1.0]
  27.     p_accels = [0.0] * dpb.pole_number
  28.     p_vels = [0.0] * dpb.pole_number
  29.     dpb.poles = []
  30.     for i in range(dpb.pole_number):
  31.         dpb.poles.append(cart_pole.Pole(p_angles[i],
  32.                                         p_vels[i],
  33.                                         p_accels[i],
  34.                                         p_masses[i],
  35.                                         p_h_lens[i]))
  36.     dpb.cart_pos = 0.0
  37.     dpb.cart_vel = 0.0
  38.     dpb.cart_acc = 0.0
  39.     dpb.cart_mass = 1.0
  40.     dpb.time = 0.0
  41.     dpb.applied_force = 0.0
  42.     dpb.track_limit = 2.4
  43.     dpb.p_failure_angle = rad(36)
  44.     dpb.time_step = 0.01
  45.     dpb.cart_fric = 0.05
  46.     dpb.p_fric = 0.000002
  47.     dpb.stop_at_zero_deg = True
  48.     return dpb
  49.  
  50.  
  51. def get_normilized_dpb_input_with_vel(dpb):
  52.     return np.clip([dpb.cart_pos / dpb.track_limit,
  53.                     dpb.cart_vel / 4.0,
  54.                     dpb.poles[0].angle / dpb.p_failure_angle,
  55.                     dpb.poles[0].vel / 5.0,
  56.                     dpb.poles[1].angle / dpb.p_failure_angle,
  57.                     dpb.poles[1].vel / 4.0], -1, 1)
  58.  
  59.  
  60. def get_normilized_dpb_input_no_vel(dpb):
  61.     return np.clip([dpb.cart_pos / dpb.track_limit,
  62.                     dpb.poles[0].angle / dpb.p_failure_angle,
  63.                     dpb.poles[1].angle / dpb.p_failure_angle], -1, 1)
  64.  
  65.  
  66. # Use the NN network phenotype and the discrete actuator force function.
  67. def evaluate_genome(g):
  68.     net = nn.create_feed_forward_phenotype(g)
  69.  
  70.     fitnesses = []
  71.  
  72.     for runs in range(runs_per_net):
  73.  
  74.         dpb = dpb_factory()
  75.  
  76.         # Run the given simulation for up to num_steps time steps.
  77.         fitness = 0.0
  78.  
  79.         for s in range(num_steps):
  80.  
  81.             # 3 no velocity, 6 with velocity
  82.             inputs = get_normilized_dpb_input_with_vel(dpb)
  83.             action = net.serial_activate(inputs)
  84.             # Apply action to the simulated cart-pole
  85.             force = cart_pole.discrete_actuator_force(action)
  86.             dpb.applied_force = force
  87.             dpb.update_state()
  88.             dpb.update_state()
  89.             if dpb.failed:
  90.                 break;
  91.  
  92.             fitness += 1.0
  93.  
  94.         fitnesses.append(fitness)
  95.  
  96.     # The genome's fitness is its worst performance across all runs.
  97.     return min(fitnesses)
  98.  
  99.  
  100. # Load the config file, which is assumed to live in
  101. # the same directory as this script.
  102. local_dir = os.path.dirname(__file__)
  103. config = Config(os.path.join(local_dir, 'nn_config'))
  104.  
  105. pop = population.Population(config)
  106. pe = parallel.ParallelEvaluator(4, evaluate_genome)
  107. pop.run(pe.evaluate, 100)
  108.  
  109. # Save the winner.
  110. print('Number of evaluations: {0:d}'.format(pop.total_evaluations))
  111. winner = pop.statistics.best_genome()
  112. with open('nn_winner_genome', 'wb') as f:
  113.     pickle.dump(winner, f)
  114.  
  115. print(winner)
Advertisement
Add Comment
Please, Sign In to add comment