from datetime import datetime
from statistics import mean

from dqn_agent import DQNAgent
from logs import CustomTensorBoard
from tetris import Tetris
from tqdm import tqdm

# Run DQN with Tetris
def dqn():
    env = Tetris()
    episodes = 2000                  # total training episodes
    max_steps = None                 # no per-episode step cap
    epsilon_stop_episode = 1500      # episode at which exploration stops decaying
    mem_size = 20000                 # replay memory capacity
    discount = 0.95                  # future-reward discount factor (gamma)
    batch_size = 512
    epochs = 1
    render_every = 50                # render one game every N episodes
    log_every = 50                   # print stats every N episodes
    replay_start_size = 2000         # samples required before training starts
    train_every = 1                  # train after every episode
    n_neurons = [32, 32]             # hidden layer sizes
    render_delay = None
    activations = ['relu', 'relu', 'linear']
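    # Note: activations lists one entry per hidden layer plus a final
    # 'linear' entry (assumption: the agent's network ends in a single
    # linear unit that scores a whole board state).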

    agent = DQNAgent(env.get_state_size(),
                     n_neurons=n_neurons, activations=activations,
                     epsilon_stop_episode=epsilon_stop_episode, mem_size=mem_size,
                     discount=discount, replay_start_size=replay_start_size)

    log_dir = f'logs/tetris-nn={n_neurons}-mem={mem_size}-bs={batch_size}-e={epochs}-{datetime.now().strftime("%Y%m%d-%H%M%S")}'
    log = CustomTensorBoard(log_dir=log_dir)

    scores = []

    for episode in tqdm(range(episodes)):
        current_state = env.reset()
        done = False
        steps = 0

        # Render one game every render_every episodes
        render = bool(render_every and episode % render_every == 0)

        # Game: let the agent pick the best-scoring afterstate among all
        # reachable placements for the current piece
        while not done and (not max_steps or steps < max_steps):
            next_states = env.get_next_states()
            best_state = agent.best_state(next_states.values())

            # Map the chosen afterstate back to the action that produces it
            best_action = None
            for action, state in next_states.items():
                if state == best_state:
                    best_action = action
                    break

            reward, done = env.play(best_action[0], best_action[1], render=render,
                                    render_delay=render_delay)

            agent.add_to_memory(current_state, next_states[best_action], reward, done)
            current_state = next_states[best_action]
            steps += 1

        scores.append(env.get_game_score())

        # Train
        if episode % train_every == 0:
            agent.train(batch_size=batch_size, epochs=epochs)

        # Logs
        if log_every and episode and episode % log_every == 0:
            avg_score = mean(scores[-log_every:])
            min_score = min(scores[-log_every:])
            max_score = max(scores[-log_every:])

            print(episode, avg_score, min_score, max_score)
            # log.log(episode, avg_score=avg_score, min_score=min_score,
            #         max_score=max_score)


if __name__ == "__main__":
    dqn()
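
For reference, here is a minimal sketch of the DQNAgent interface this script assumes. The real class lives in dqn_agent.py of the original project; only the method names (best_state, add_to_memory, train) and the constructor keywords used above come from the script itself. The network shape, optimizer, and epsilon schedule below are assumptions, not the original implementation.

# --- Hedged sketch: the DQNAgent interface the script above assumes ---
import random
from collections import deque

import numpy as np
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential


class DQNAgent:
    def __init__(self, state_size, n_neurons=(32, 32),
                 activations=('relu', 'relu', 'linear'),
                 epsilon_stop_episode=1500, mem_size=20000,
                 discount=0.95, replay_start_size=2000):
        self.memory = deque(maxlen=mem_size)
        self.discount = discount
        self.replay_start_size = replay_start_size
        self.epsilon = 1.0
        self.epsilon_decay = 1.0 / epsilon_stop_episode

        # One Dense layer per entry in n_neurons, then a single linear
        # unit: the network scores a whole board state (afterstate)
        # instead of emitting one Q-value per action.
        self.model = Sequential()
        self.model.add(Dense(n_neurons[0], input_dim=state_size,
                             activation=activations[0]))
        for units, act in zip(n_neurons[1:], activations[1:-1]):
            self.model.add(Dense(units, activation=act))
        self.model.add(Dense(1, activation=activations[-1]))
        self.model.compile(loss='mse', optimizer='adam')

    def best_state(self, states):
        # Epsilon-greedy over candidate afterstates: explore with a
        # random pick, otherwise take the highest predicted value.
        states = list(states)
        if random.random() <= self.epsilon:
            return random.choice(states)
        values = self.model.predict(np.array(states), verbose=0).flatten()
        return states[int(np.argmax(values))]

    def add_to_memory(self, current_state, next_state, reward, done):
        self.memory.append((current_state, next_state, reward, done))

    def train(self, batch_size=512, epochs=1):
        # Replay training: wait for enough samples, then fit the network
        # on one-step bootstrapped targets r + gamma * V(s').
        if len(self.memory) < max(self.replay_start_size, batch_size):
            return
        batch = random.sample(self.memory, batch_size)
        next_values = self.model.predict(
            np.array([b[1] for b in batch]), verbose=0).flatten()
        x = np.array([b[0] for b in batch])
        y = np.array([r + (0 if d else self.discount * v)
                      for (_, _, r, d), v in zip(batch, next_values)])
        self.model.fit(x, y, batch_size=batch_size, epochs=epochs, verbose=0)
        self.epsilon = max(0.0, self.epsilon - self.epsilon_decay)

Scoring afterstates rather than (state, action) pairs sidesteps Tetris's piece-dependent action space: the game loop simply asks the environment for every reachable board and lets the network rank them.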
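
The other project-specific import is CustomTensorBoard from logs.py. Below is a minimal stand-in, assuming TensorFlow 2's tf.summary API and the keyword signature of the commented-out log.log(...) call; the original helper may instead subclass keras.callbacks.TensorBoard.

# Hedged sketch of the logging helper: writes each keyword argument of
# log() (e.g. avg_score=...) as a named scalar at the given step.
import tensorflow as tf


class CustomTensorBoard:
    def __init__(self, log_dir):
        self.writer = tf.summary.create_file_writer(log_dir)

    def log(self, step, **stats):
        with self.writer.as_default():
            for name, value in stats.items():
                tf.summary.scalar(name, value, step=step)
        self.writer.flush()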