# alteragents.py -- DQN agent variants used by the runner script below
# (the module name is inferred from the `import alteragents` in that script).
import os
import sys
import time

import numpy as np
import tensorflow as tf
from alternets import *
from dopamine.agents.dqn import dqn_agent


class DCAgent(dqn_agent.DQNAgent):
    """DQN agent whose training loss sums every loss registered in the graph."""

    def _build_train_op(self):
        replay_action_one_hot = tf.one_hot(
            self._replay.actions, self.num_actions, 1., 0., name='action_one_hot')
        replay_chosen_q = tf.reduce_sum(
            self._replay_net_outputs.q_values * replay_action_one_hot,
            axis=1,
            name='replay_chosen_q')

        target = tf.stop_gradient(self._build_target_q_op())
        # huber_loss is added to tf.GraphKeys.LOSSES by default, so it is
        # picked up again by tf.losses.get_losses() below.
        huber_loss = tf.losses.huber_loss(
            target, replay_chosen_q, reduction=tf.losses.Reduction.NONE)

        # Sum the Huber loss with any auxiliary losses the network (from
        # alternets) may have registered in the LOSSES collection.
        total_loss = tf.math.add_n(
            [tf.reduce_mean(loss) for loss in tf.losses.get_losses()])
        if self.summary_writer is not None:
            with tf.variable_scope('Losses'):
                tf.summary.scalar('HuberLoss', tf.reduce_mean(huber_loss))
                tf.summary.scalar('Total_loss', total_loss)
        return self.optimizer.minimize(total_loss)

    def bundle_and_checkpoint(self, checkpoint_dir, iteration_number):
        # Returning None disables checkpointing for this agent.
        return None

class fbAgent(dqn_agent.DQNAgent):
    """DQN agent that periodically blends the online weights with a stored
    snapshot of themselves ("feedback" mixing)."""

    def __init__(self, *args, **kwargs):
        super(fbAgent, self).__init__(*args, **kwargs)
        self.fb_data = {}        # last snapshot of each online variable
        self.fb_assign_ops = {}  # assign op per online variable
        self.fb_a = 0.5          # mixing coefficient for the current weights
        self.fb_k = 30           # blend every fb_k * update_period training steps
        self._fb_init()

    def _fb_init(self):
        # One shared placeholder feeds a dedicated assign op per online variable.
        self._fb_ph = tf.placeholder(tf.float32, shape=None, name="FBPH")
        for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Online'):
            self.fb_assign_ops[var.name] = tf.assign(var, self._fb_ph)

    def _fb_back(self):
        for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Online'):
            if var.name in self.fb_data:
                # Blend the current value with the previous snapshot, then
                # write the result back into the variable.
                target_value = (self._sess.run(var) * self.fb_a
                                + self.fb_data[var.name] * (1 - self.fb_a))
                self._sess.run(self.fb_assign_ops[var.name],
                               feed_dict={self._fb_ph: target_value})
            # Record the (possibly just-updated) value as the next snapshot.
            self.fb_data[var.name] = self._sess.run(var)

    def _train_step(self):
        if (self._replay.memory.add_count > self.min_replay_history
                and self.training_steps % (self.fb_k * self.update_period) == 0):
            self._fb_back()
        super(fbAgent, self)._train_step()

    def bundle_and_checkpoint(self, checkpoint_dir, iteration_number):
        # Returning None disables checkpointing for this agent.
        return None

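# Added note (not part of the original paste): with the defaults above,
# _fb_back() blends each online variable with the snapshot stored at the
# previous call,
#
#     w_new = fb_a * w_current + (1 - fb_a) * w_snapshot
#
# which for fb_a = 0.5 is a plain average, e.g. w_current = 2.0 and
# w_snapshot = 0.0 give w_new = 1.0, and 1.0 then becomes the next snapshot.
# With update_period = 1 (as set in the runner script below) the blend fires
# once every fb_k = 30 training steps, after the replay warm-up.
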
class MIIMAgent(dqn_agent.DQNAgent):
    def _build_sync_op(self):
        """Builds ops that sync the online and target networks.

        Returns:
          ops: A list of ops assigning both networks the elementwise mean of
            their current weights.
        """
        # Get trainable variables from online and target DQNs.
        sync_qt_ops = []
        trainables_online = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope='Online')
        trainables_target = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope='Target')
        for (w_online, w_target) in zip(trainables_online, trainables_target):
            # Meet in the middle: assign both networks the average of their
            # current weights instead of copying online -> target.
            middle = (w_online + w_target) / 2
            sync_qt_ops.append(w_target.assign(middle, use_locking=True))
            sync_qt_ops.append(w_online.assign(middle, use_locking=True))
        return sync_qt_ops

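# Added note (not part of the original paste): the stock Dopamine DQNAgent
# builds its sync op as a one-way copy, roughly w_target.assign(w_online).
# MIIMAgent._build_sync_op above instead moves both networks to the midpoint,
#
#     w_target, w_online <- (w_online + w_target) / 2
#
# so every target update also pulls the online network back toward the target.
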
# Runner script (a separate file from alteragents.py above; it imports that
# module and trains one of the agent variants on an Atari game).
import argparse
import os
import sys

import numpy as np
import tensorflow as tf
import gin.tf
from absl import flags

from dopamine.agents.dqn import dqn_agent
from dopamine.colab import utils as colab_utils
from dopamine.discrete_domains import atari_lib
from dopamine.discrete_domains import run_experiment

import alteragents
from alternets import *
# import setGPU

parser = argparse.ArgumentParser()
parser.add_argument('-m', '--mod', action='store')
parser.add_argument('-g', '--game', action='store')
parser.add_argument('-f', '--folder', action='store')
args = parser.parse_args()
print(args)
config_str = args.mod

tf.logging.set_verbosity(tf.logging.INFO)
GAME = args.game  # @param
# Logs land under ae/<folder>/<mod>/data/<game>/<random run id>.
BASE_PATH = os.path.join("ae", str(args.folder), str(config_str), "data",
                         str(GAME), str(np.random.randint(1000000)))
metaAgent_config = """
# Hyperparameters follow the classic Nature DQN, but we modify as necessary to
# match those used in Rainbow (Hessel et al., 2018), to ensure apples-to-apples
# comparison.
import dopamine.discrete_domains.atari_lib
import dopamine.discrete_domains.run_experiment
import dopamine.replay_memory.circular_replay_buffer
import gin.tf.external_configurables

atari_lib.create_atari_environment.game_name = '{}'
# Sticky actions with probability 0.25, as suggested by (Machado et al., 2017).
atari_lib.create_atari_environment.sticky_actions = True
Runner.num_iterations = 40
Runner.training_steps = 250000  # agent steps
Runner.evaluation_steps = 125000  # agent steps
Runner.max_steps_per_episode = 27000  # agent steps

WrappedReplayBuffer.replay_capacity = 1000000
WrappedReplayBuffer.batch_size = 32
""".format(GAME)
gin.parse_config(metaAgent_config, skip_unknown=False)
LOG_PATH = BASE_PATH

# Map the --mod flag to a network, an agent class and a target update period.
settings = {}
settings['base'] = {'network': atari_lib.nature_dqn_network,
                    'agent': alteragents.DCAgent,
                    'target_update_period': 8000}
settings['fb'] = {'network': atari_lib.nature_dqn_network,
                  'agent': alteragents.fbAgent,
                  'target_update_period': 80}
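# Added note (not part of the original paste): MIIMAgent is defined in
# alteragents but has no entry here, so it cannot be selected via --mod.
# A hypothetical entry would look like the following (the
# target_update_period value is a guess, not taken from the original):
# settings['miim'] = {'network': atari_lib.nature_dqn_network,
#                     'agent': alteragents.MIIMAgent,
#                     'target_update_period': 8000}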

if args.mod in settings:
    network = settings[args.mod]['network']
    agent = settings[args.mod]['agent']
    target_update_period = settings[args.mod]['target_update_period']
else:
    raise ValueError("Unknown --mod value {!r}; expected one of {}".format(
        args.mod, sorted(settings)))


def create_agent(sess, environment, summary_writer=None):
    # DQN hyperparameters follow the Nature DQN setup, apart from the
    # per-variant target_update_period chosen above.
    return agent(sess, num_actions=environment.action_space.n,
                 summary_writer=summary_writer,
                 gamma=0.99, update_horizon=1, min_replay_history=20000,
                 update_period=1, target_update_period=target_update_period,
                 epsilon_train=0.01, epsilon_eval=0.001,
                 epsilon_decay_period=250000, tf_device='/gpu:0',
                 max_tf_checkpoints_to_keep=0,
                 optimizer=tf.train.RMSPropOptimizer(
                     learning_rate=0.00025,
                     decay=0.95,
                     momentum=0.0,
                     epsilon=0.00001,
                     centered=True),
                 network=network)


runner = run_experiment.TrainRunner(LOG_PATH, create_agent)
print(LOG_PATH)
runner.run_experiment()
print(GAME + ' Done training!')
sys.exit()
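
# Example invocation (added note, not part of the original paste; the file and
# run names are illustrative). Assuming this runner is saved as run_variants.py
# next to alteragents.py and alternets.py:
#
#     python run_variants.py --mod fb --game Pong --folder exp1
#
# --mod must be 'base' or 'fb', --game is an Atari game name accepted by
# dopamine's atari_lib (e.g. 'Pong'), and logs are written under
# ae/<folder>/<mod>/data/<game>/<random run id>.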