Q-Learning noob
import gym
import os
import random
import time
import pickle
import tensorflow as tf
import numpy as np

# Wall-clock time in milliseconds
microtime = lambda: int(round(time.time() * 1000))
start_t = microtime()

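# Episode-level replay memory: each entry is one episode's list of
# [s, a, r, s', done] transitions. sample() draws batch_size episodes at
# random and returns a contiguous trace of up to trace_length transitions
# from each, concatenated into one array.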
class ExperienceBuffer():
    def __init__(self, buf_size):
        self.buf_size = buf_size

        self.buffer = []

    # Store one episode's experiences, dropping the oldest episode when full
    def add(self, experience):
        if len(self.buffer) + 1 > self.buf_size:
            self.buffer.pop(0)

        self.buffer.append(experience)

    # Retrieve a random sample of experience traces
    def sample(self, batch_size, trace_length):
        sampled_episodes = random.sample(self.buffer, batch_size)
        sampled_traces = []

        for episode in sampled_episodes:
            # Last valid trace start; episodes shorter than trace_length
            # are returned whole, starting at index 0
            p = max(1, len(episode) + 1 - trace_length)

            pt = np.random.randint(0, p)
            sampled_traces = sampled_traces + episode[pt:pt + trace_length]

        return np.array(sampled_traces)

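# Q-function for a small discrete state/action space: the state is one-hot
# encoded and multiplied with a weight matrix W, so W is effectively a
# num_states x num_actions Q-table. Training minimizes the squared error
# between Q(s, .) and a target vector using gradient descent.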
class QNetwork():
    def __init__(self, num_states, num_actions, save_file=None, gamma=0.99, lr=0.1):
        self.num_states = num_states
        self.num_actions = num_actions
        self.save_file = save_file
        self.gamma = gamma
        self.lr = lr

        if save_file is not None:
            try:
                load = pickle.load(open(save_file, "rb"))
                self.W = tf.Variable(load)

                print("Loaded %s" % save_file)
                print(load)
                print("")
            except FileNotFoundError:
                self.W = tf.Variable(tf.random_uniform(
                    [self.num_states, self.num_actions],
                    0,
                    0.01),
                dtype=tf.float32)
            except Exception as _e:
                print(_e)
                # Fall back to a fresh random table if loading fails
                self.W = tf.Variable(tf.random_uniform([self.num_states, self.num_actions], 0, 0.01), dtype=tf.float32)
        else:
            # Same small random initialization as when no save file exists
            self.W = tf.Variable(tf.random_uniform([self.num_states, self.num_actions], 0, 0.01), dtype=tf.float32)

        self.input_state = tf.placeholder(shape=[None], dtype=tf.int32, name="input_state")
        self.input_state_one_hot = tf.one_hot(
            indices=tf.cast(self.input_state, tf.int32),
            depth=self.num_states
        )
        self.Q = tf.matmul(self.input_state_one_hot, self.W)
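        # Multiplying the one-hot state vector with W just selects row `state`
        # of W, so each row of W holds the Q-values of one state and
        # best_action (defined below) is an argmax over that row.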
        self.Q_target = tf.placeholder(
            shape=[None, self.num_actions],
            dtype=tf.float32,
            name="Q_target"
        )
        self.best_action = tf.argmax(self.Q, 1)

        # Summed squared error between the target and predicted Q-values
        self.loss = tf.reduce_sum(tf.square(self.Q_target - self.Q))
        self.trainer = tf.train.GradientDescentOptimizer(learning_rate=lr)
        self.train_op = self.trainer.minimize(self.loss)

    def save(self, val):
        if self.save_file is None:
            return

        pickle.dump(val, open(self.save_file, "wb"))

# Setup
train = True
batch_train = True
test = True

pre_train_steps = 50000
train_freq = 25

num_episodes = 10000
num_episodes_test = 100
num_steps = 100

e_start = 0.1
e_end = 0.01

#QN1 = QNetwork(16, 4, save_file="FrozenLake-v0.p", gamma=0.99, lr=0.1)
QN1 = QNetwork(16, 4, gamma=0.99, lr=0.1)

# Variables
env = gym.make("FrozenLake-v0")
env = gym.wrappers.Monitor(env, "tmp/FrozenLake-0.1", force=True)
exp_buf = ExperienceBuffer(1000)

e_factor = 2.0 * ((e_start - e_end) / num_episodes)
e = e_start

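# Epsilon schedule: e starts at e_start and is reduced by e_factor each time a
# reward is collected (and when a rewarding transition is replayed), clamped
# at e_end. With the factor of 2.0, e reaches e_end after roughly
# num_episodes / 2 such decrements.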
# Timing buckets: save, get_act, step, get_new_Qs, train
bench = [[], [], [], [], []]

# Add an operation to initialize global variables.
init_op = tf.global_variables_initializer()

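# Training loop: act epsilon-greedily, store every transition, and nudge W
# toward the TD target r + gamma * max_a' Q(s', a'). Once more than
# pre_train_steps environment steps have accumulated, the update is computed
# from minibatches sampled out of the replay buffer every train_freq steps.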
# Training
with tf.Session() as sess:
    sess.run(init_op)

    if train:
        print("Training started\n")

        batch_training_started = False
        total_batch_trained = 0
        all_rewards = []
        all_steps = []

        for episode in range(num_episodes):
            os.system("title \"Training... Episode %i/%i\"" % (episode, num_episodes))

            # Every 100 episodes: save the weights and print progress
            if episode % 100 == 0 and episode != 0:
                t = microtime()
                W_val = sess.run(QN1.W)
                QN1.save(W_val)

                print("Episodes %04d - %04d: %i succeeded, %.2f avg steps/episode, e=%.4f" % (
                        episode - 100,
                        episode,
                        sum(all_rewards[-100:]),
                        np.mean(all_steps[-100:]),
                        e
                    )
                )
                bench[0].append((microtime() - t))

            # Reset episode-specific parameters
            state = env.reset()
            steps = 0
            episode_reward = 0
            episode_buffer = [] # s, a, r, s', d
            done = False

            # Do steps in the game
            while steps <= num_steps:
                if done:
                    break

                # Obtain the best action and current Q-values for this state
                t = microtime()
                act, curr_Qs = sess.run([QN1.best_action, QN1.Q], feed_dict={
                    QN1.input_state: [state]
                })
                bench[1].append((microtime() - t))
                act = act[0]

                # An e chance of selecting a random action instead
                if np.random.rand(1) < e:
                    act = env.action_space.sample()

                # Advance the environment one step
                t = microtime()
                new_state, reward, done, _ = env.step(act)
                bench[2].append((microtime() - t))

                # Store this experience
                episode_buffer.append([state, act, reward, new_state, done])

                # Train from memory
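                # Replay update: sample short traces of past transitions and,
                # for each one, move the taken action's Q-value toward
                # r + gamma * max_a' Q(s', a'). Q(s, .) is evaluated as well so
                # the target vector keeps the current estimates for the actions
                # that were not taken.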
                # Count steps from the current episode too, so training can
                # trigger every train_freq steps
                total_steps = sum(all_steps) + steps
                if batch_train and (total_steps > pre_train_steps) and ((total_steps % train_freq) == 0):
                    if not batch_training_started:
                        batch_training_started = True
                        print("Batch training started")

                    training_batch = exp_buf.sample(4, 4)

                    t = microtime()
                    training_states = [int(x[3]) for x in training_batch] # s'
                    batch_new_Qs = sess.run(QN1.Q, feed_dict={
                        QN1.input_state: training_states
                    }) # Q(s', a')

                    training_states = [int(x[0]) for x in training_batch] # s
                    batch_curr_Qs = sess.run(QN1.Q, feed_dict={
                        QN1.input_state: training_states
                    }) # Q(s, a)
                    bench[3].append((microtime() - t))

                    # Best possible outcome of the new states (per state)
                    new_Qs_max = np.max(batch_new_Qs, 1) # max a' for Q(s', a')

                    target_Qs = batch_curr_Qs.copy()
                    for i, experience in enumerate(training_batch):
                        s, a, r, ss, d = experience # s a r s' d

                        # Replaying a rewarding transition also lowers epsilon
                        if int(r) == 1:
                            e -= e_factor

                            if e < e_end:
                                e = e_end

                        # Target for the taken action: r + gamma * max_a' Q(s', a')
                        target_Qs[i][int(a)] = r + QN1.gamma * new_Qs_max[i]

                    total_batch_trained += len(training_batch)
                else:
                    # Online update: obtain Q-values for the new state if the
                    # replay buffer wasn't used this step
                    t = microtime()
                    new_Qs = sess.run(QN1.Q, feed_dict={
                        QN1.input_state: [new_state]
                    })
                    bench[3].append((microtime() - t))

                    # Best possible outcome of the new state
                    new_Qs_max = np.max(new_Qs)

                    # Set target_Qs for the old state
                    target_Qs = curr_Qs.copy()
                    target_Qs[0, act] = reward + QN1.gamma * new_Qs_max

                    training_states = [state]

                # Train with the given state(s) and target_Qs
                t = microtime()
                sess.run(QN1.train_op, feed_dict={
                    QN1.input_state: training_states,
                    QN1.Q_target: target_Qs
                }) # train with target and s
                bench[4].append((microtime() - t))

                steps += 1
                episode_reward += reward
                state = new_state

            # Decrease the exploration rate after every successful episode
            if episode_reward > 0:
                e -= e_factor

                if e < e_end:
                    e = e_end

            all_rewards.append(episode_reward)
            all_steps.append(steps)

            # Store this episode's experiences
            exp_buf.add(episode_buffer)

        # Save the final weights
        W_val = sess.run(QN1.W)
        QN1.save(W_val)

        print("\nCompleted %i organic steps" % sum(all_steps))
        print("Completed %i batch-trained steps" % total_batch_trained)

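    # Evaluation: run the greedy policy (argmax Q, no random actions) for
    # num_episodes_test episodes and report how many reached the goal.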
    if test:
        # Testing
        print("\nTesting...")

        all_rewards = []
        all_steps = []

        for episode in range(num_episodes_test):
            os.system("title \"Testing... Episode %i/%i\"" % (episode, num_episodes_test))

            # Reset episode-specific parameters
            state = env.reset()
            steps = 0
            episode_reward = 0
            done = False

            # Do steps in the game
            while steps <= num_steps:
                if done:
                    break

                act = sess.run(QN1.best_action, feed_dict={
                    QN1.input_state: [state]
                })
                act = act[0]

                new_state, reward, done, _ = env.step(act)

                steps += 1
                episode_reward += reward
                state = new_state

            all_rewards.append(episode_reward)
            all_steps.append(steps)

        print("Finished. %i/%i succeeded, avg. steps %.2f" % (
            sum(all_rewards),
            num_episodes_test,
            np.mean(all_steps)
        ))

print("\nTimes:\nsave, get_act, step, get_new_Qs, train:")
print(", ".join([str(sum(t)) for t in bench]))

print("\nTotal took %i ms" % (microtime() - start_t))
env.close()
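# Note: this script uses the TensorFlow 1.x API (tf.placeholder, tf.Session,
# tf.train.GradientDescentOptimizer); under TensorFlow 2 it would need the
# tf.compat.v1 compatibility layer with v2 behaviour disabled.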