import numpy as np

import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam

from rl.agents.dqn import DQNAgent
from rl.memory import SequentialMemory
from rl.policy import EpsGreedyQPolicy
from rl.policy import LinearAnnealedPolicy

from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT

# keras-rl is built on the TF1-style graph API, so disable eager execution.
tf.compat.v1.disable_eager_execution()
print(tf.executing_eagerly())  # should print False

# Build the NES environment and restrict it to the simplified action set.
env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, SIMPLE_MOVEMENT)

# Seed numpy and the environment for reproducibility.
np.random.seed(123)
env.seed(123)

nb_actions = env.action_space.n  # size of SIMPLE_MOVEMENT

# Q-values are unbounded, so the network must end in a linear activation;
# a softmax output would squash the regression targets onto a probability simplex.

# Option 1: simple linear model
#model = Sequential()
#model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
#model.add(Dense(nb_actions))
#model.add(Activation('linear'))

# Option 2: deeper fully-connected network. The extra leading dimension in
# input_shape is the agent's window_length (set to 1 in the memory below).
model = Sequential()
model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(nb_actions))
model.add(Activation('linear'))

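# Option 3 (sketch, not used below): the observations are raw RGB frames of
# shape (240, 256, 3), so a convolutional network is usually a much better fit
# than flattening ~184k pixels into small Dense layers. Filter counts and
# kernel sizes here are illustrative assumptions, not tuned values.
#from keras.layers import Reshape, Conv2D
#model = Sequential()
#model.add(Reshape(env.observation_space.shape,
#                  input_shape=(1,) + env.observation_space.shape))
#model.add(Conv2D(32, (8, 8), strides=4, activation='relu'))
#model.add(Conv2D(64, (4, 4), strides=2, activation='relu'))
#model.add(Conv2D(64, (3, 3), strides=1, activation='relu'))
#model.add(Flatten())
#model.add(Dense(512, activation='relu'))
#model.add(Dense(nb_actions, activation='linear'))
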
model.summary()  # summary() prints directly; dqn.compile() below handles compilation

# Anneal exploration linearly from fully random (eps=1.0) down to eps=0.1
# over the first 100k steps; use eps=0.05 during testing.
policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps', value_max=1., value_min=.1,
                              value_test=.05, nb_steps=100000)
# Replay buffer of the last 50k transitions; window_length=1 means each
# state is a single frame (matching the model's input_shape above).
memory = SequentialMemory(limit=50000, window_length=1)

# Collect 10k steps of experience before learning starts; target_model_update < 1
# means the target network is updated softly (tau=1e-2) rather than hard-copied.
dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=10000,
               target_model_update=1e-2, policy=policy)

dqn.compile(Adam(lr=1e-3), metrics=['mae'])

# Train for 100k steps with the emulator rendered, then run 5 evaluation episodes.
dqn.fit(env, nb_steps=100000, visualize=True, verbose=2)

dqn.test(env, nb_episodes=5, visualize=True)
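
# Optional: persist the trained weights so a later run can reload them with
# dqn.load_weights(). The filename is an arbitrary choice, not from the script above.
dqn.save_weights('dqn_mario_weights.h5f', overwrite=True)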