Advertisement
Guest User

tran.py

a guest
Dec 6th, 2023
16
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.43 KB | None | 0 0
  1. import tensorflow.compat.v1 as tf
  2. from open_spiel.python.algorithms import dqn
  3. from open_spiel.python import rl_environment
  4. import random
  5.  
  6. tf.disable_v2_behavior()
  7.  
  8. game = "laser_tag"
  9. num_players = 2
  10.  
  11. # env_configs = {"columns": 5, "rows": 5}
  12. env = rl_environment.Environment("laser_tag")
  13. num_actions = env.action_spec()["num_actions"]
  14. num_train_episodes = 10
  15.  
  16. checkpoint_dir = '/Users/shawnsun/Desktop/CS486/Versions/models/dqn10'
  17.  
  18.  
  19.  
  20. # Set DQN hyperparameters
  21. hidden_layers_sizes = 64
  22. replay_buffer_capacity = int(1e5)
  23. batch_size = 32
  24. num_train_episodes = int(1)
  25.  
  26. # Train the DQN model
  27. with tf.Session() as sess:
  28. model = dqn.DQN(
  29. session=sess,
  30. player_id=0,
  31. state_representation_size=env.observation_spec()["info_state"][0],
  32. num_actions=num_actions,
  33. hidden_layers_sizes=hidden_layers_sizes,
  34. replay_buffer_capacity=replay_buffer_capacity,
  35. batch_size=batch_size)
  36. sess.run(tf.global_variables_initializer())
  37.  
  38. for ep in range(num_train_episodes):
  39. time_step = env.reset()
  40. while not time_step.last():
  41. action = model.step(time_step).action
  42. # legal_actions = time_step.observation
  43. time_step = env.step([action, random.choice(range(num_actions))])
  44.  
  45. model.step(time_step)
  46.  
  47. model_variables = tf.trainable_variables()
  48. for var in model_variables:
  49. print(var)
  50. model.save(checkpoint_dir)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement