Advertisement
KDT85

tttrl

May 20th, 2024
395
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.68 KB | None | 0 0
  1. import numpy as np
  2. import time
  3. import pickle
  4. import pandas as pd
  5. from tqdm import tqdm
  6.  
  7. EMPTY = '_'
  8. GRID_SIZE = 3
  9. PLAYER = ['X', 'O']
  10.  
  11. def show_board(board):
  12.     board = np.array(board)
  13.     print()
  14.     for i in range(GRID_SIZE):
  15.         for j in range(GRID_SIZE):
  16.             print('|', end='')
  17.             print(board[i, j], end='')
  18.         print('|')
  19.     print()
  20.  
  21. def get_legal_moves(board):
  22.     legal_moves = []
  23.     for i in range(len(board)):
  24.         for j in range(len(board[i])):
  25.             if board[i][j] == EMPTY:
  26.                 legal_moves.append((i, j))
  27.     return legal_moves
  28.  
  29. def get_human_move(player):
  30.     human = input(f"Player {player}, enter a square: ")
  31.     valid = map_index(human)
  32.     while not valid:
  33.         print("Invalid input! Please enter a number between 1 and 9.")
  34.         return get_human_move(player)
  35.     return valid[1]
  36.  
  37. def map_index(int_choice):
  38.     num_pad_map = {'7': (0, 0), '8': (0, 1), '9': (0, 2), '4': (1, 0), '5': (1, 1), '6': (1, 2), '1': (2, 0), '2': (2, 1), '3': (2, 2)}
  39.     if int_choice in num_pad_map:
  40.         return True, num_pad_map[int_choice]
  41.     else:
  42.         return False
  43.  
  44. def make_move(move, board, player):
  45.     row, col = move
  46.     if board[row, col] == EMPTY:
  47.         board[row, col] = player
  48.     else:
  49.         print("That square is already taken!")
  50.         move = get_human_move(player)  # Ask for input again
  51.         make_move(move, board, player)
  52.     return board
  53.  
  54. def check_win(board):
  55.     # Check rows
  56.     for row in board:
  57.         if row[0] == row[1] == row[2] and row[0] != EMPTY:
  58.             return True, row[0]
  59.  
  60.     # Check columns
  61.     for col in range(3):
  62.         if board[0][col] == board[1][col] == board[2][col] and board[0][col] != EMPTY:
  63.             return True, board[0][col]
  64.  
  65.     # Check diagonals
  66.     if board[0][0] == board[1][1] == board[2][2] and board[0][0] != EMPTY:
  67.         return True, board[0][0]
  68.     if board[0][2] == board[1][1] == board[2][0] and board[0][2] != EMPTY:
  69.         return True, board[0][2]
  70.  
  71.     # Check draw
  72.     if all(board[i][j] != EMPTY for i in range(3) for j in range(3)):
  73.         return True, 'Draw'
  74.     return False, None
  75.  
  76. def state_to_index(state):
  77.     mapping = {'X': 1, 'O': -1, EMPTY: 0}
  78.     index = 0
  79.     for i, value in enumerate(state):
  80.         index += (3 ** i) * mapping[value]
  81.     return abs(index)
  82.  
  83. def get_action(legal_moves, Q_table, state_index, epsilon):
  84.     if np.random.rand() < epsilon:
  85.         action = np.random.choice(len(legal_moves))
  86.         return legal_moves[action]
  87.     else:
  88.         best_action = None
  89.         best_q_value = float('-inf')
  90.         for action in legal_moves:
  91.             action_index = action[0] * GRID_SIZE + action[1]
  92.             q_value = Q_table[state_index][action_index]
  93.             if q_value > best_q_value:
  94.                 best_q_value = q_value
  95.                 best_action = action
  96.         return best_action
  97.  
  98. def opponent(moves):
  99.     move = np.random.randint(len(moves))
  100.     return moves[move]
  101.  
  102. def train_agent(episodes):
  103.     q_table = np.zeros((3 ** (GRID_SIZE * GRID_SIZE), GRID_SIZE * GRID_SIZE))
  104.     start_time = time.time()
  105.     print("Training started...")
  106.  
  107.     for episode in tqdm(range(episodes), desc="Training Progress"):
  108.         winner = False
  109.         board = [[EMPTY] * GRID_SIZE for _ in range(GRID_SIZE)]
  110.         turn = 0
  111.         while not winner:
  112.             legal_moves = get_legal_moves(board)
  113.             state = tuple(np.array(board).flatten())
  114.             state_index = state_to_index(state)
  115.  
  116.             if PLAYER[turn % 2] == 'X':
  117.                 action = get_action(legal_moves, q_table, state_index, epsilon)
  118.                 row, col = action
  119.             else:
  120.                 row, col = opponent(legal_moves)
  121.  
  122.             board[row][col] = PLAYER[turn % 2]
  123.             winner, result = check_win(board)
  124.  
  125.             if winner:
  126.                 if result == 'X':
  127.                     reward = 1
  128.                 elif result == 'O':
  129.                     reward = -1
  130.                 else:
  131.                     reward = 0
  132.             else:
  133.                 reward = 0
  134.  
  135.             next_state = tuple(np.array(board).flatten())
  136.             next_state_index = state_to_index(next_state)
  137.             action_index = row * GRID_SIZE + col
  138.             q_table[state_index][action_index] += alpha * (reward + gamma * np.max(q_table[next_state_index]) - q_table[state_index][action_index])
  139.  
  140.             turn += 1
  141.  
  142.     end_time = time.time()
  143.     elapsed_time = end_time - start_time
  144.     print(f"Training time: {elapsed_time} seconds")
  145.     return q_table
  146.  
  147. def save_q_table(q_table, filename):
  148.     with open(filename, 'wb') as f:
  149.         pickle.dump(q_table, f)
  150.  
  151. def load_q_table(filename):
  152.     with open(filename, 'rb') as f:
  153.         q_table = pickle.load(f)
  154.     return q_table
  155.  
  156. def save_q_table_to_excel(q_table, filename):
  157.     df = pd.DataFrame(q_table)
  158.     df.to_excel(filename, index=False)
  159.     print(df)
  160.  
  161. def play_game(q_table, wins, losses, draws):
  162.     board = np.zeros((GRID_SIZE, GRID_SIZE), dtype=str)
  163.     board.fill(EMPTY)
  164.     game_over = False
  165.     round = 0
  166.  
  167.     while not game_over:
  168.         show_board(board)
  169.         if PLAYER[round % 2] == 'X':
  170.             state = tuple(np.array(board).flatten())
  171.             state_index = state_to_index(state)
  172.             move = get_action(get_legal_moves(board), q_table, state_index, 0)
  173.             print(f"AI played at {move}")
  174.         else:
  175.             #move = get_human_move(PLAYER[round % 2])
  176.             move = opponent(get_legal_moves(board))
  177.         board = make_move(move, board, PLAYER[round % 2])
  178.         game_over, result = check_win(board)
  179.         round += 1
  180.  
  181.     show_board(board)
  182.     if result == 'Draw':
  183.         print("It's a draw!")
  184.         draws += 1
  185.     else:
  186.         print(f"Player {result} wins!")
  187.         if result == 'X':
  188.             wins +=1
  189.         else:
  190.             losses +=1
  191.    
  192.     return (wins, losses, draws)
  193.  
  194. # Initialize Q-learning parameters
  195. alpha = 0.3  # Learning rate
  196. gamma = 0.9  # Discount factor
  197. epsilon = 1  # Exploration rate
  198. episodes = 20000000  # number of episodes for training
  199.  
  200. # Train the agent
  201. #q_table = train_agent(episodes)
  202.  
  203. # Save the Q-table
  204. #save_q_table(q_table, 'q_table.pkl')
  205. #save_q_table_to_excel(q_table, f'q_table_{episodes}_episodes.xlsx')
  206.  
  207. # Load the Q-table (if needed)
  208. q_table = load_q_table('q_table.pkl')
  209.  
  210. wins, losses, draws = 0, 0, 0
  211. # Play against the trained AI
  212. for i in range(1000):
  213.     wins, losses, draws = play_game(q_table, wins, losses, draws)
  214.  
  215.  
  216.  
  217. # Create a DataFrame
  218. df = pd.DataFrame({
  219.     'Result': ['Wins', 'Losses', 'Draws'],
  220.     'Count': [wins, losses, draws]
  221. })
  222.  
  223. # Print the DataFrame
  224. print(df.to_string(index=False))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement