import numpy as np
import sys

from keras.models import Sequential
from keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, LSTM, Reshape
from keras.optimizers import Adam

from absl import flags

from pysc2.env import sc2_env, environment
from pysc2.lib import actions
from pysc2.lib import features

from rl.memory import SequentialMemory
from rl.policy import LinearAnnealedPolicy, EpsGreedyQPolicy
from rl.core import Processor
from rl.callbacks import FileLogger, ModelIntervalCheckpoint
from rl.agents.dqn import DQNAgent

# Actions from the pysc2 API

FUNCTIONS = actions.FUNCTIONS
_PLAYER_RELATIVE = features.SCREEN_FEATURES.player_relative.index
_PLAYER_FRIENDLY = 1
_PLAYER_NEUTRAL = 3
_PLAYER_HOSTILE = 4
_NO_OP = FUNCTIONS.no_op.id
_MOVE_SCREEN = FUNCTIONS.Move_screen.id
_ATTACK_SCREEN = FUNCTIONS.Attack_screen.id
_SELECT_ARMY = FUNCTIONS.select_army.id
_NOT_QUEUED = [0]
_SELECT_ALL = [0]
_HAL_ADEPT = FUNCTIONS.Hallucination_Adept_quick.id
_HAL_ARCHON = FUNCTIONS.Hallucination_Archon_quick.id
_HAL_COL = FUNCTIONS.Hallucination_Colossus_quick.id
_HAL_DISRUP = FUNCTIONS.Hallucination_Disruptor_quick.id
_HAL_HIGTEM = FUNCTIONS.Hallucination_HighTemplar_quick.id
_HAL_IMN = FUNCTIONS.Hallucination_Immortal_quick.id
_HAL_PHOENIX = FUNCTIONS.Hallucination_Phoenix_quick.id
_HAL_STALKER = FUNCTIONS.Hallucination_Stalker_quick.id
_HAL_VOIDRAID = FUNCTIONS.Hallucination_VoidRay_quick.id
_HAL_ZEALOT = FUNCTIONS.Hallucination_Zealot_quick.id
_FORCE_FIELD = FUNCTIONS.Effect_ForceField_screen.id
_GUARD_FIELD = FUNCTIONS.Effect_GuardianShield_quick.id

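# Note on the two [0] constants above: pysc2 FunctionCall arguments are lists
# of ints, so _NOT_QUEUED = [0] means "act now" for a `queued` argument, and
# _SELECT_ALL = [0] means "replace the current selection" for select_army's
# `select_add` argument.
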
# Size of the screen and length of the observation window

_SIZE = 64
_WINDOW_LENGTH = 1

# Load and save weights for training

LOAD_MODEL = False  # Set to True once a training run has already saved weights
SAVE_MODEL = True

# Global variables shared between the environment and the training loop

episode_reward = 0
observation_cur = None

# Configure flags for running the model from the console

FLAGS = flags.FLAGS
flags.DEFINE_string("mini_game", "HallucinIce", "Name of the minigame")
flags.DEFINE_string("algorithm", "deepq", "RL algorithm to use")

# Processor

class SC2Proc(Processor):
    def process_observation(self, observation):
        """Process the observation obtained from the environment for use by an agent, and return it."""
        obs = observation[0].observation["feature_screen"][_PLAYER_RELATIVE]
        return np.expand_dims(obs, axis=2)

    def process_state_batch(self, batch):
        """Process an entire batch of states and return it."""
        batch = np.swapaxes(batch, 0, 1)
        return batch[0]

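# Shape sketch (illustrative, assuming the 64x64 screen configured below):
# process_observation turns the (64, 64) player_relative layer into
# (64, 64, 1) to match the Conv2D input, and process_state_batch drops
# keras-rl's window axis, which is trivial here because _WINDOW_LENGTH = 1:
#
#   (batch_size, 1, 64, 64, 1) -> (batch_size, 64, 64, 1)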


def args_random(actfunc):
    """Build a random (but well-formed) argument list for a pysc2 function.

    E.g. actfunc = pysc2.lib.actions.FUNCTIONS[81]
    """
    args_given = []
    for arg in actfunc.args:
        arg_values = []
        for size in arg.sizes:
            if size == 0:
                # Plain FUNCTIONS entries report size 0 for spatial
                # (screen/minimap) dimensions, so pin those coordinates to 0
                arg_values.append(0)
            else:
                arg_values.append(np.random.randint(0, size))
        args_given.append(arg_values)
    return args_given

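# Illustrative call (output varies since values are sampled): select_army's
# only argument is select_add with sizes (2,), so this returns [[0]] or [[1]].
#
#   args_random(actions.FUNCTIONS[_SELECT_ARMY])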


# Define the environment

class Environment(sc2_env.SC2Env):
    """StarCraft II environment. Implementation details in lib/features.py."""

    def step(self, action):
        """Apply actions, step the world forward, and return observations.

        Note: the action chosen by the agent is currently ignored. After the
        initial army selection, a random action is sampled from the actions
        available in the current observation (see the TO-DO below).
        """
        global episode_reward
        global observation_cur
        if observation_cur is None:
            action = actions.FunctionCall(_SELECT_ARMY, [_SELECT_ALL])
        else:
            print('AVAILABLE ACTIONS:')
            for a in observation_cur.available_actions:
                print(actions.FUNCTIONS[a])
            function_id = np.random.choice(observation_cur.available_actions)
            print('selected action id ', function_id)
            args_chosen = args_random(actions.FUNCTIONS[function_id])
            action = actions.FunctionCall(function_id, args_chosen)
        obs = super(Environment, self).step([action])
        observation_cur = obs[0].observation
        observation = obs
        r = obs[0].reward
        done = obs[0].step_type == environment.StepType.LAST
        episode_reward += r

        # keras-rl expects a gym-style (observation, reward, done, info) tuple
        return observation, r, done, {}

    def reset(self):
        global episode_reward
        global observation_cur
        episode_reward = 0
        super(Environment, self).reset()

        # Start every episode with the whole army selected, and refresh the
        # cached observation so step() does not act on stale available_actions
        obs = super(Environment, self).step([actions.FunctionCall(_SELECT_ARMY, [_SELECT_ALL])])
        observation_cur = obs[0].observation
        return obs


# def actions_to_choose():
#     print('not using action from dqn.forward(observation)')
#     hall = [_HAL_ADEPT, _HAL_ARCHON, _HAL_COL, _HAL_DISRUP,
#             _HAL_HIGTEM, _HAL_IMN, _HAL_PHOENIX, _HAL_STALKER,
#             _HAL_VOIDRAID, _HAL_ZEALOT, _FORCE_FIELD, _GUARD_FIELD]
#     # action = actions.FunctionCall(_HAL_ADEPT, [_NOT_QUEUED])
#     # action = actions.FunctionCall(_GUARD_FIELD, [_NOT_QUEUED])
#     action = actions.FunctionCall(_SELECT_ARMY, [_SELECT_ALL])
#     print(action)
#     return action

# TO-DO: Define actions_to_choose based on the SC2 sentry unit
# (a first sketch follows below)

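# A minimal sketch toward that TO-DO (not wired into the training loop; the
# index-to-action mapping is an assumption): translate the index chosen by
# dqn.forward into one of the 12 sentry-style actions defined above, falling
# back to no_op when the chosen function is not currently available.

def index_to_function_call(action_index, available_actions):
    hall = [_HAL_ADEPT, _HAL_ARCHON, _HAL_COL, _HAL_DISRUP,
            _HAL_HIGTEM, _HAL_IMN, _HAL_PHOENIX, _HAL_STALKER,
            _HAL_VOIDRAID, _HAL_ZEALOT, _FORCE_FIELD, _GUARD_FIELD]
    function_id = hall[action_index]
    if function_id not in available_actions:
        return actions.FunctionCall(_NO_OP, [])
    return actions.FunctionCall(function_id, args_random(actions.FUNCTIONS[function_id]))
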
# Agent architecture using keras-rl

def neural_network_model(input_shape, nb_actions):
    model = Sequential()

    # CNN feature extractor
    model.add(Conv2D(256, kernel_size=(5, 5), input_shape=input_shape))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=None, padding='valid', data_format=None))
    model.add(Flatten())

    model.add(Dense(256, activation='relu'))
    model.add(Reshape((1, 256)))

    model.add(LSTM(256))
    model.add(Dense(nb_actions, activation='softmax'))
    model.summary()
    # DQNAgent.compile() below recompiles the model with its own loss, so this
    # compile mainly makes the model usable on its own
    model.compile(loss="categorical_crossentropy",
                  optimizer="adam",
                  metrics=["accuracy"])

    return model

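# Layer-by-layer output shapes for the (64, 64, 1) input used below
# (illustrative):
#   Conv2D(256, 5x5)     -> (60, 60, 256)
#   MaxPooling2D(2x2)    -> (30, 30, 256)
#   Flatten              -> (230400,)
#   Dense(256) + Reshape -> (1, 256), a length-1 sequence for the LSTM
#   LSTM(256)            -> (256,)
#   Dense(nb_actions)    -> (12,)
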
def training_game():
    env = Environment(map_name="HallucinIce", visualize=True, game_steps_per_episode=150,
                      agent_interface_format=features.AgentInterfaceFormat(
                          feature_dimensions=features.Dimensions(screen=64, minimap=32)
                      ))

    input_shape = (_SIZE, _SIZE, 1)
    nb_actions = 12  # Number of actions (the 12 sentry-style functions above)

    model = neural_network_model(input_shape, nb_actions)
    memory = SequentialMemory(limit=5000, window_length=_WINDOW_LENGTH)

    processor = SC2Proc()

    # Policy: anneal epsilon linearly from 1.0 down to 0.7 over the first
    # million steps

    policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr="eps", value_max=1, value_min=0.7,
                                  value_test=.0, nb_steps=1000000)

    # Agent

    dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, enable_double_dqn=False,
                   nb_steps_warmup=500, target_model_update=1e-2, policy=policy,
                   batch_size=150,
                   processor=processor)

    dqn.compile(Adam(lr=0.001), metrics=["mae"])

    # Save the weights during training and reload them when needed

    name = "HallucinIce"
    w_file = "dqn_{}_weights.h5f".format(name)
    check_w_file = "train_w" + name + "_weights.h5f"

    if SAVE_MODEL:
        check_w_file = "train_w" + name + "_weights_{step}.h5f"

    log_file = "training_w_{}_log.json".format(name)
    callbacks = [ModelIntervalCheckpoint(check_w_file, interval=1000)]
    callbacks += [FileLogger(log_file, interval=100)]

    if LOAD_MODEL:
        dqn.load_weights(w_file)

    dqn.fit(env, callbacks=callbacks, nb_steps=10000000, action_repetition=2,
            log_interval=10000, verbose=2)

    dqn.save_weights(w_file, overwrite=True)
    dqn.test(env, action_repetition=2, nb_episodes=30, visualize=False)


if __name__ == '__main__':
    FLAGS(sys.argv)
    training_game()
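
# Example invocation (the script name is hypothetical; both flags are parsed
# but not yet consumed by training_game):
#
#   python hallucinice_dqn.py --mini_game=HallucinIce --algorithm=deepq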