jonksar

KERAS: Reinforcement network

Aug 3rd, 2016
import numpy as np
import gym

from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import RMSprop

from rl.agents.dqn import DQNAgent
from rl.policy import BoltzmannQPolicy
from rl.memory import SequentialMemory


ENV_NAME = 'CartPole-v0'

HLS = 100  # hidden layer size

# Setting the env up
env = gym.make(ENV_NAME)
n_action = env.action_space.n
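# Note on the spaces: CartPole-v0 observations are 4-dimensional vectors
# (cart position, cart velocity, pole angle, pole tip velocity), and the
# action space is Discrete(2), so n_action == 2 here.
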
# Build the neural network
model = Sequential()

model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
model.add(Dense(HLS))
model.add(Activation('relu'))
model.add(Dense(HLS))
model.add(Activation('relu'))
model.add(Dense(HLS))
model.add(Activation('relu'))
model.add(Dense(n_action))
model.add(Activation('linear'))
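# The output layer stays linear because the network predicts one Q-value
# per action; DQNAgent then acts on these raw estimates.
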
model.summary()  # How did it go? * ( ' ^')*

# Configure the RL agent

# window_length=1 matches the (1,) prefix in the Flatten input_shape above
memory = SequentialMemory(limit=50000, window_length=1)
policy = BoltzmannQPolicy()
dqn = DQNAgent(model=model, nb_actions=n_action, memory=memory, nb_steps_warmup=10,
               target_model_update=1e-2, policy=policy)
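# target_model_update < 1 means soft target-network updates (tau = 0.01);
# BoltzmannQPolicy samples actions from a softmax over the predicted
# Q-values rather than acting epsilon-greedily.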

dqn.compile(optimizer=RMSprop(), metrics=['mae'])

dqn.fit(env, nb_steps=100000, verbose=2, visualize=False)

dqn.test(env, nb_episodes=5, visualize=True)
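
# Optional follow-up (not in the original paste): keras-rl agents expose
# save_weights / load_weights, so the trained policy can be kept around
# and reloaded into an identically-built agent. The filename is only an
# illustrative choice.
dqn.save_weights('dqn_{}_weights.h5f'.format(ENV_NAME), overwrite=True)
# dqn.load_weights('dqn_{}_weights.h5f'.format(ENV_NAME))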