Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Train a DQN agent (player 0) on OpenSpiel's laser_tag against an opponent
# (player 1) that plays uniformly at random among its legal actions, then
# print the trainable variables and checkpoint the agent to `checkpoint_dir`.
import os
import random

import tensorflow.compat.v1 as tf
from open_spiel.python import rl_environment
from open_spiel.python.algorithms import dqn

# open_spiel.python.algorithms.dqn is the TF1 graph-mode implementation and
# takes an explicit tf.Session, so v2 behavior must be disabled.
tf.disable_v2_behavior()

# --- Environment -----------------------------------------------------------
game = "laser_tag"
num_players = 2  # player 0: DQN agent, player 1: random opponent
env = rl_environment.Environment(game)
num_actions = env.action_spec()["num_actions"]

# --- Training / checkpoint settings ----------------------------------------
# NOTE(review): the original script set this to 10 and then immediately
# overwrote it with 1; the effective value (1) is kept. The checkpoint
# directory name ("dqn10") suggests 10 may have been intended — confirm.
num_train_episodes = 1
# NOTE(review): machine-specific absolute path; adjust per machine.
checkpoint_dir = "/Users/shawnsun/Desktop/CS486/Versions/models/dqn10"

# --- DQN hyperparameters ---------------------------------------------------
hidden_layers_sizes = [64]  # one hidden layer with 64 units
replay_buffer_capacity = 100_000
batch_size = 32

# --- Training --------------------------------------------------------------
with tf.Session() as sess:
    model = dqn.DQN(
        session=sess,
        player_id=0,
        state_representation_size=env.observation_spec()["info_state"][0],
        num_actions=num_actions,
        hidden_layers_sizes=hidden_layers_sizes,
        replay_buffer_capacity=replay_buffer_capacity,
        batch_size=batch_size,
    )
    sess.run(tf.global_variables_initializer())

    for _ in range(num_train_episodes):
        time_step = env.reset()
        while not time_step.last():
            action = model.step(time_step).action
            # Sample the opponent's move from its legal actions rather than
            # from the full action range.
            opponent_action = random.choice(
                time_step.observations["legal_actions"][1])
            # laser_tag is simultaneous-move: one action per player per step.
            time_step = env.step([action, opponent_action])
        # Step the agent once more on the terminal time step so it also
        # observes the episode's final transition.
        model.step(time_step)

    model_variables = tf.trainable_variables()
    for var in model_variables:
        print(var)

    # save() runs through the session, so it must happen before the `with`
    # block closes it. The TF1 Saver does not create missing directories.
    os.makedirs(checkpoint_dir, exist_ok=True)
    model.save(checkpoint_dir)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement