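# Policy gradient agent with a continuous (Gaussian) action for OpenAI Gym's
# Pendulum-v0, built on the fn_framework helpers (FNAgent, Trainer, Observer).
# The script relies on TensorFlow 1.x-style APIs (tf.placeholder,
# tf.distributions) and the Keras bundled with TensorFlow.
#
# Usage (the file name is only an example):
#   python policy_gradient_continuous.py          # train, plot rewards, save
#   python policy_gradient_continuous.py --play   # play with the trained model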
import os
import argparse
import random
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.externals import joblib
import tensorflow as tf
from tensorflow.python import keras as K
import gym
from fn_framework import FNAgent, Trainer, Observer


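# Actor-critic agent for a 1-dimensional continuous action:
# the actor predicts the mean (mu) of a Gaussian with a fixed standard
# deviation of 0.05, SampleLayer draws a clipped sample from it, and the
# critic scores (state, action) pairs whose TD error weights the policy loss.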
class PolicyGradientContinuousAgent(FNAgent):

    def __init__(self, epsilon, low, high):
        super().__init__(epsilon, [low, high])
        self.scaler = None
        self._updater = None

    def save(self, model_path):
        super().save(model_path)
        joblib.dump(self.scaler, self.scaler_path(model_path))

    @classmethod
    def load(cls, env, model_path, epsilon=0.0001):
        low, high = [env.action_space.low[0], env.action_space.high[0]]
        agent = cls(epsilon, low, high)
        agent.model = K.models.load_model(model_path, custom_objects={
            "SampleLayer": SampleLayer})
        agent.scaler = joblib.load(agent.scaler_path(model_path))
        return agent

    def scaler_path(self, model_path):
        fname, _ = os.path.splitext(model_path)
        fname += "_scaler.pkl"
        return fname

    def initialize(self, experiences, actor_optimizer, critic_optimizer):
        self.scaler = StandardScaler()
        states = np.vstack([e.s for e in experiences])
        self.scaler.fit(states)
        feature_size = states.shape[1]

        base = K.models.Sequential()
        base.add(K.layers.Dense(24, activation="relu",
                                input_shape=(feature_size,)))

        # Actor
        # define action distribution (mean of a Gaussian policy)
        mu = K.layers.Dense(1, activation="tanh")(base.output)
        mu = K.layers.Lambda(lambda m: m * 2)(mu)  # scale tanh to [-2, 2]
        # sigma = K.layers.Dense(1, activation="softplus")(base.output)
        # self.dist_model = K.Model(inputs=base.input, outputs=[mu, sigma])
        self.dist_model = K.Model(inputs=base.input, outputs=[mu])

        # sample action from the distribution
        low, high = self.actions
        action = SampleLayer(low, high)(mu)
        self.model = K.Model(inputs=base.input, outputs=[action])

        # Critic: estimates the value of a (state, action) pair
        self.critic = K.models.Sequential([
            K.layers.Dense(24, activation="relu",
                           input_shape=(feature_size + 1,)),
            K.layers.Dense(1, activation="linear")
        ])
        self.set_updater(actor_optimizer)
        self.critic.compile(loss="mse", optimizer=critic_optimizer)
        self.initialized = True
        print("Done initializing. From now on, begin training!")

    def set_updater(self, optimizer):
        actions = tf.placeholder(shape=(None), dtype="float32")
        td_error = tf.placeholder(shape=(None), dtype="float32")

        # Actor loss: -log pi(a|s) weighted by the critic's TD error
        mu = self.dist_model.output
        action_dist = tf.distributions.Normal(loc=tf.squeeze(mu),
                                              scale=0.05)
        action_probs = action_dist.prob(tf.squeeze(actions))
        clipped = tf.clip_by_value(action_probs, 1e-10, 1.0)
        # Squeeze td_error so the product stays per-sample instead of
        # broadcasting an (N,) vector against an (N, 1) column.
        loss = - tf.log(clipped) * tf.squeeze(td_error)
        loss = tf.reduce_mean(loss)

        updates = optimizer.get_updates(loss=loss,
                                        params=self.model.trainable_weights)
        self._updater = K.backend.function(
            inputs=[self.model.input,
                    actions, td_error],
            outputs=[loss, action_probs, mu],
            updates=updates)

    def policy(self, s):
        if np.random.random() < self.epsilon or not self.initialized:
            low, high = self.actions
            return np.random.uniform(low, high)
        else:
            normalized_s = self.scaler.transform(s)
            action = self.model.predict(normalized_s)[0]
            return action[0]

    def update(self, batch, gamma):
        states = np.vstack([e.s for e in batch])
        normalized_s = self.scaler.transform(states)
        actions = np.vstack([e.a for e in batch])

        # Calculate TD targets from the critic's value of the next state/action
        next_states = np.vstack([e.n_s for e in batch])
        normalized_n_s = self.scaler.transform(next_states)
        n_s_actions = self.model.predict(normalized_n_s)
        feature_n = np.concatenate([normalized_n_s, n_s_actions], axis=1)
        n_s_values = self.critic.predict(feature_n)
        values = [b.r + gamma * (0 if b.d else 1) * n_v
                  for b, n_v in zip(batch, n_s_values)]
        values = np.array(values)

        feature = np.concatenate([normalized_s, actions], axis=1)
        td_error = values - self.critic.predict(feature)
        a_loss, probs, mu = self._updater([normalized_s, actions, td_error])
        c_loss = self.critic.train_on_batch(feature, values)

        print([a_loss, c_loss])
        """
        for x in zip(actions, mu, probs):
            print("Took action {}. (mu={}, its prob={})".format(*x))
        """


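# Keras layer that turns the predicted mean into an action by sampling from
# Normal(mu, 0.05) and clipping the sample to the valid action range.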
class SampleLayer(K.layers.Layer):

    def __init__(self, low, high, **kwargs):
        self.low = low
        self.high = high
        super(SampleLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        super(SampleLayer, self).build(input_shape)

    def call(self, x):
        mu = x
        actions = tf.distributions.Normal(loc=tf.squeeze(mu),
                                          scale=0.05).sample([1])
        actions = tf.clip_by_value(actions, self.low, self.high)
        return tf.reshape(actions, (-1, 1))

    def compute_output_shape(self, input_shape):
        return (input_shape[0], 1)

    def get_config(self):
        config = super().get_config()
        config["low"] = self.low
        config["high"] = self.high
        return config


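# Observer for Pendulum-v0: wraps the scalar action into the 1-element list
# the environment expects and reshapes states to a (1, feature_size) batch.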
class PendulumObserver(Observer):

    def step(self, action):
        n_state, reward, done, info = self._env.step([action])
        return self.transform(n_state), reward, done, info

    def transform(self, state):
        return np.reshape(state, (1, -1))


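# Trainer that, once training has started, updates the agent at every step on
# a random mini-batch drawn from the experience buffer.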
class PolicyGradientContinuousTrainer(Trainer):

    def __init__(self, buffer_size=4096, batch_size=32,
                 gamma=0.9, report_interval=10, log_dir=""):
        super().__init__(buffer_size, batch_size, gamma,
                         report_interval, log_dir)

    def train(self, env, episode_count=220, epsilon=0.1, initial_count=-1,
              render=False):
        low, high = [env.action_space.low[0], env.action_space.high[0]]
        agent = PolicyGradientContinuousAgent(epsilon, low, high)

        self.train_loop(env, agent, episode_count, initial_count, render)
        return agent

    def begin_train(self, episode, agent):
        actor_optimizer = K.optimizers.Adam()
        critic_optimizer = K.optimizers.Adam()
        agent.initialize(self.experiences, actor_optimizer, critic_optimizer)

    def step(self, episode, step_count, agent, experience):
        if self.training:
            batch = random.sample(self.experiences, self.batch_size)
            agent.update(batch, self.gamma)

    def episode_end(self, episode, step_count, agent):
        reward = sum([e.r for e in self.get_recent(step_count)])
        self.reward_log.append(reward)

        if self.is_event(episode, self.report_interval):
            recent_rewards = self.reward_log[-self.report_interval:]
            self.logger.describe("reward", recent_rewards, episode=episode)


def main(play):
    env = PendulumObserver(gym.make("Pendulum-v0"))
    trainer = PolicyGradientContinuousTrainer()
    path = trainer.logger.path_of("policy_gradient_continuous_agent.h5")

    if play:
        agent = PolicyGradientContinuousAgent.load(env, path)
        agent.play(env)
    else:
        trained = trainer.train(env, episode_count=100, render=False)
        trainer.logger.plot("Rewards", trainer.reward_log,
                            trainer.report_interval)
        trained.save(path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="PG Agent Pendulum-v0")
    parser.add_argument("--play", action="store_true",
                        help="play with trained model")

    args = parser.parse_args()
    main(args.play)