Advertisement
Guest User

Untitled

a guest
Mar 23rd, 2018
84
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.08 KB | None | 0 0
  1. import math
  2. from statistics import mean
  3.  
  4. from numpy import random
  5. from time import sleep
  6. import numpy
  7. from operator import add
  8. import gym
  9. import matplotlib.pyplot as plt
  10.  
  11. from visualizer import plot_learning_result
  12.  
  13.  
  14. class QLearner:
  15.     def __init__(self, epsilon, alpha, gamma, epsilon_decay, alpha_k):
  16.         self.environment = gym.make('CartPole-v1')
  17.         self.attempt_no = 1
  18.         self.upper_bounds = [
  19.             self.environment.observation_space.high[0],
  20.             -3.0,
  21.             self.environment.observation_space.high[2],
  22.             -3.5
  23.         ]
  24.         self.lower_bounds = [
  25.             self.environment.observation_space.low[0],
  26.             3.0,
  27.             self.environment.observation_space.low[2],
  28.             3.5
  29.         ]
  30.  
  31.         self.bins_cartpos = [0.0]  # [-0.05, 0.05]
  32.         self.bins_cartvel = [0.0]  # [-0.42, 0.42]
  33.         self.bins_pole_angle = [-0.2+i*0.08 for i in range(7)]
  34.         self.bins_pole_vel = [-.8+i*0.27 for i in range(7)]
  35.  
  36.         self.epsilon = epsilon
  37.         self.alpha = alpha
  38.         self.gamma = gamma
  39.         self.epsilon_decay = epsilon_decay
  40.         self.alpha_k = alpha_k
  41.         self.minimal_epsilon = 0.02
  42.  
  43.         # self.stats = [{} for _ in range(4)]
  44.         self.all_observations = []
  45.  
  46.         self.knowledge = {(a, b, c, d, i): random.random() for a in range(3) for b in range(3) for c in range(8) for d in range(8) for i in range(2)}
  47.  
  48.     def learn(self, max_attempts):
  49.         return [self.attempt() for _ in range(max_attempts)]
  50.  
  51.     def feed_stats(self, obs):
  52.         for i, o in enumerate(obs):
  53.             if str(o) in self.stats[i]:
  54.                 self.stats[i][str(o)] += 1
  55.             else:
  56.                 self.stats[i][str(o)] = 0
  57.  
  58.     def attempt(self):
  59.         observation = self.discretise(self.environment.reset())
  60.         done = False
  61.         reward_sum = 0.0
  62.         while not done:
  63.             # self.environment.render()
  64.             action = self.pick_action(observation)
  65.             new_observation, reward, done, info = self.environment.step(action)
  66.             self.all_observations.append(new_observation)
  67.             new_observation = self.discretise(new_observation)
  68.             # self.feed_stats(new_observation)
  69.             # print(new_observation)
  70.             self.update_knowledge(action, observation, new_observation, reward)
  71.             observation = new_observation
  72.             reward_sum += reward
  73.         # print("{} {}".format(reward_sum, self.epsilon))
  74.         self.epsilon *= self.epsilon_decay
  75.         # if self.epsilon < self.minimal_epsilon:
  76.         #     self.epsilon = self.minimal_epsilon
  77.         #     self.epsilon_decay = 1
  78.         self.alpha *= self.alpha_k
  79.         self.attempt_no += 1
  80.         return reward_sum
  81.  
  82.     def discretise(self, observation):
  83.         obs_a = numpy.digitize(observation[0], self.bins_cartpos)
  84.         obs_b = numpy.digitize(observation[1], self.bins_cartvel)
  85.         obs_c = numpy.digitize(observation[2], self.bins_pole_angle)
  86.         obs_d = numpy.digitize(observation[3], self.bins_pole_vel)
  87.         return [obs_a[()], obs_b[()], obs_c[()], obs_d[()]]
  88.  
  89.     def pick_action(self, observation):
  90.         action_left = self.knowledge[(*observation, 0)]
  91.         action_right = self.knowledge[(*observation, 1)]
  92.         if random.random() > self.epsilon:
  93.             return 0 if action_left > action_right else 1
  94.         else:
  95.             return random.randint(0, 2)
  96.  
  97.     def update_knowledge(self, action, observation, new_observation, reward):
  98.         qval = self.knowledge[(*observation, action)]
  99.         maxqval = max(self.knowledge[(*new_observation, 0)], self.knowledge[(*new_observation, 1)])
  100.         new_qval = qval + self.alpha*(reward+self.gamma*maxqval-qval)
  101.         self.knowledge[(*observation, action)] = new_qval
  102.  
  103.     def plot_histogram(self):
  104.         for i in range(4):
  105.             plt.hist([o[i] for o in self.all_observations], bins=10)
  106.             plt.show()
  107.  
  108.  
  109. def main():
  110.     exp_count = 5
  111.     total_score = []
  112.  
  113.     # starting parameters
  114.     alpha = 0.9
  115.     gamma = 1.0
  116.     epsilon = 1.0
  117.     alpha_k = 0.996
  118.     epsilon_decay = 0.996
  119.  
  120.     # experiment parameters
  121.     # alpha_l = [0.1 + 0.1*i for i in range(9)]
  122.     # gamma_l = [0.0+0.2*i for i in range(6)]
  123.     # epsilon_l = [0.2 + 0.2*i for i in range(5)]
  124.     # alpha_k_l = [0.999, 0.996]
  125.     # epsilon_decay_l = [0.999, 0.996]
  126.     #
  127.     # best_params = [0.0, 0.0, 0.0, 0.0, 0.0]
  128.     # best_score = 0.0
  129.     # cnt = 0
  130.     # for a in alpha_l:
  131.     #     for b in gamma_l:
  132.     #         for c in epsilon_l:
  133.     #             for d in alpha_k_l:
  134.     #                 for e in epsilon_decay_l:
  135.     #                     sum = 0.0
  136.     #                     cnt += 1
  137.     #                     print(cnt)
  138.     #
  139.     #                     for i in range(exp_count):
  140.     #                         learner = QLearner(a, b, c, e, d)
  141.     #                         score = learner.learn(600)
  142.     #                         if len(total_score) == 0:
  143.     #                             total_score = [s for s in score]
  144.     #                         else:
  145.     #                             total_score = list(map(add, score, total_score))
  146.     #                         sum += mean(score[-50:])
  147.     #                     if sum > best_score:
  148.     #                         best_score = sum
  149.     #                         best_params = [a,b,c,d,e]
  150.     #                         print(best_params)
  151.     # print(best_params)
  152.  
  153.     hill = False
  154.     hill_params = [alpha, gamma, epsilon]
  155.     res = 0
  156.     res_float = 0.0
  157.  
  158.     if hill:
  159.         while res < 3:
  160.             hill_params[res] += 0.1
  161.             print(res)
  162.             print(hill_params[res])
  163.             total_score = []
  164.             for i in range(5):
  165.                 print(i)
  166.                 learner = QLearner(alpha, gamma, epsilon, epsilon_decay, alpha_k)
  167.                 score = learner.learn(1000)
  168.                 if len(total_score) == 0:
  169.                     total_score = [s for s in score]
  170.                 else:
  171.                     total_score = list(map(add, score, total_score))
  172.             res_temp = mean(total_score[-30:])
  173.             if res_temp < res_float or hill_params[res] > 1.0:
  174.                 res += 1
  175.                 res_float = 0.0
  176.             else:
  177.                 res_float = res_temp
  178.             print(res_temp)
  179.         print(hill_params)
  180.     else:
  181.         for i in range(exp_count):
  182.             learner = QLearner(epsilon, alpha, gamma, epsilon_decay, alpha_k)
  183.             score = learner.learn(2000)
  184.             if len(total_score) == 0:
  185.                 total_score = [s for s in score]
  186.             else:
  187.                 total_score = list(map(add, score, total_score))
  188.             print(mean(score[-20:]))
  189.         total_score = [t/5 for t in total_score]
  190.         # learner.plot_histogram()
  191.         plot_learning_result(total_score, 100,
  192.                              {'alpha': learner.alpha, 'gamma': learner.gamma, 'epsilon': learner.epsilon,
  193.                               'epsilon decay': learner.epsilon_decay, 'alpha decay': learner.alpha_k})
  194.  
  195.  
# Run the experiment only when executed as a script, not on import.
if __name__ == '__main__':
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement