Advertisement
Guest User

Untitled

a guest
Apr 6th, 2020
218
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.16 KB | None | 0 0
  1. def get_legal_actions(str_state):
  2.     #TODO (1) : Get the actions Greuceanu can do
  3.     local_actions = deepcopy(ACTIONS)
  4.     return_actions = []
  5.    
  6.     for a in local_actions:
  7.         state, r, m = apply_action(str_state, a)
  8.         posx, posy = __get_position(state, "G")
  9.         if __is_valid_cell(state, posx, posy):
  10.             return_actions += [a]
  11.    
  12.     return return_actions
  13.  
  14. def epsilon_greedy(Q, state, legal_actions, epsilon):
  15.     # TODO (2) : Epsilon greedy
  16.     not_explored = []
  17.    
  18.     for a in legal_actions:
  19.         if (state, a) not in Q:
  20.             not_explored += [a]
  21.     if not_explored != []:
  22.         return choice(not_explored)
  23.    
  24.     max_action = -9999
  25.     max_a = ""
  26.    
  27.     for a in legal_actions:
  28.         if Q[(state, a)] > max_action:
  29.             max_action = Q[(state, a)]
  30.             max_a = a
  31.    
  32.     if random() <= epsilon:
  33.         return choice(legal_actions)
  34.     return max_a
  35.  
  36. def best_action(Q, state, legal_actions):
  37.     # TODO (3) : Best action
  38.     max_action = -9999
  39.     max_a = ""
  40.    
  41.     for a in legal_actions:
  42.         if (state, a) not in Q:
  43.             Q[(state, a)] = 0
  44.         if Q[(state, a)] > max_action:
  45.             max_action = Q[(state, a)]
  46.             max_a = a
  47.     return max_a
  48.  
def q_learning():
    """Train a tabular Q-learning agent, periodically evaluate the greedy
    policy, then optionally replay one episode and plot score curves.

    Relies on module-level helpers/constants defined elsewhere in the file
    (get_initial_state, apply_action, is_final_state, display_state,
    clear_output, MAP_NAME, TRAIN_EPISODES, EPSILON, LEARNING_RATE,
    DISCOUNT_FACTOR, EVAL_EVERY, EVAL_EPISODES, VERBOSE, FINAL_SHOW,
    PLOT_SCORE, SLEEP_TIME).
    """
    Q = {}                 # Q-table: (state, action) -> estimated value
    train_scores = []      # total reward of each training episode
    eval_scores = []       # avg greedy-policy score, one entry per EVAL_EVERY episodes
    initial_state = get_initial_state(MAP_NAME)

    for train_ep in range(1, TRAIN_EPISODES+1):
        clear_output(wait=True)
        score = 0
        state = deepcopy(initial_state)

        if VERBOSE:
            display_state(state); sleep(SLEEP_TIME)
            clear_output(wait=True)

        while not is_final_state(state, score):

            actions = get_legal_actions(state)
            action = epsilon_greedy(Q, state, actions, EPSILON)

            new_state, reward, msg = apply_action(state, action)
            score += reward

            # Compute max_a' Q(new_state, a'), treating unseen pairs as 0.
            # NOTE(review): despite the name, max_action holds a Q *value*,
            # not an action.
            max_action = -9999
            new_actions = get_legal_actions(new_state)

            for a in new_actions:
                if (new_state, a) in Q:
                    max_action = max(max_action, Q[(new_state, a)])
                else:
                    max_action = max(max_action, 0)
            # NOTE(review): the TD update only fires when (new_state, action)
            # is already in Q, so most transitions are never learned from —
            # textbook Q-learning updates unconditionally.  Confirm intent.
            if (new_state, action) in Q:
                # NOTE(review): this keys Q with a Q-value in the action slot
                # ((new_state, max_action)) — looks like a bug; the entry it
                # creates is never a real (state, action) pair.
                if (new_state, max_action) not in Q:
                    Q[(new_state, max_action)] = 0

                if (state, action) not in Q:
                    Q[(state, action)] = 0
                # Standard TD update: Q += lr * (r + gamma * max_a' Q' - Q)
                Q[(state, action)] += LEARNING_RATE * (reward + DISCOUNT_FACTOR * max_action - Q[(state, action)])

            state = new_state
            # TODO (1) : Q-Learning
            if VERBOSE:
                print(msg); display_state(state); sleep(SLEEP_TIME)
                clear_output(wait=True)


        print(f"Episode {train_ep} / {TRAIN_EPISODES}")
        train_scores.append(score)

        # evaluate the greedy policy
        if train_ep % EVAL_EVERY == 0:
            avg_score = .0

#             TODO (4) : Evaluate
#             eval_scores.append(avg_score)

            for i in range(1, EVAL_EPISODES + 1):
                state = deepcopy(initial_state)
                n_score = 0
                # NOTE(review): termination checks `score` (the last training
                # episode's total), not `n_score` — probably a bug; confirm.
                while not is_final_state(state, score):
                    action = best_action(Q, state, get_legal_actions(state))
                    new_state, reward, msg = apply_action(state, action)
                    n_score += reward

                    if (state, action) not in Q:
                        Q[(state, action)] = 0
                    new_actions = get_legal_actions(new_state)
                    max_action = -9999
                    # Bootstrap value: max over next-state actions, unseen
                    # pairs initialised to 0 on the fly.
                    for new_a in new_actions:
                        if (new_state, new_a) not in Q:
                            Q[(new_state, new_a)] = 0
                        if Q[(new_state, new_a)] > max_action:
                            max_action = Q[(new_state, new_a)]
                    # NOTE(review): evaluation keeps *updating* Q here, so
                    # this is not a pure greedy evaluation — confirm intent.
                    Q[(state, action)] += LEARNING_RATE * (reward + DISCOUNT_FACTOR * max_action - Q[(state, action)])
                avg_score += n_score
            eval_scores += [avg_score / EVAL_EPISODES]
    # --------------------------------------------------------------------------
    if FINAL_SHOW:
        # Replay one episode with the greedy policy, rendering each step.
        # NOTE(review): also terminates on the stale training `score`.
        state = deepcopy(initial_state)
        while not is_final_state(state, score):
            action = best_action(Q, state, get_legal_actions(state))
            state, _, msg = apply_action(state, action)
            print(msg); display_state(state); sleep(SLEEP_TIME)
            clear_output(wait=True)

    if PLOT_SCORE:
        from matplotlib import pyplot as plt
        import numpy as np
        plt.xlabel("Episode")
        plt.ylabel("Average score")
        # Training curve smoothed with a 5-point moving average (blue).
        plt.plot(
            np.linspace(1, TRAIN_EPISODES, TRAIN_EPISODES),
            np.convolve(train_scores, [0.2,0.2,0.2,0.2,0.2], "same"),
            linewidth = 1.0, color = "blue"
        )
        # Greedy-evaluation curve, one point per EVAL_EVERY episodes (red).
        plt.plot(
            np.linspace(EVAL_EVERY, TRAIN_EPISODES, len(eval_scores)),
                        eval_scores, linewidth = 2.0, color = "red"
        )
        plt.show()
# Script entry point: train (and optionally evaluate/plot) immediately on run.
q_learning()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement