Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
"""Train and evaluate a keras-rl DQN agent on Super Mario Bros.

The script builds a small fully connected Q-network over the flattened
screen observation, trains it with an epsilon-greedy policy whose epsilon
is linearly annealed, and then runs a few rendered test episodes.
"""
import numpy as np
import gym
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten, Input
from keras.optimizers import Adam
from keras.initializers import he_normal
from rl.agents.cem import CEMAgent
from rl.memory import EpisodeParameterMemory
from rl.memory import SequentialMemory
from rl.agents.dqn import DQNAgent
from rl.policy import EpsGreedyQPolicy
from rl.policy import LinearAnnealedPolicy
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT

# Run TensorFlow in graph mode and echo the resulting eager-execution flag.
tf.compat.v1.disable_eager_execution()
print(tf.executing_eagerly())

# Restrict the NES controller to the SIMPLE_MOVEMENT action set.
env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, SIMPLE_MOVEMENT)

# Seed NumPy and the environment with the same fixed value.
np.random.seed(123)
env.seed(123)

nb_actions = env.action_space.n

# Option 1 : Simple model
#model = Sequential()
#model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
#model.add(Dense(nb_actions))
#model.add(Activation('linear'))

# Option 2: deep network
# The leading (1,) axis matches window_length=1 of the replay memory below.
model = Sequential()
model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(nb_actions))
# Q-values are unbounded regression targets, so the output layer must be
# linear; a softmax would squash them into a probability distribution.
model.add(Activation('linear'))
# summary() prints the table itself and returns None, so it is not wrapped
# in print().
model.summary()

# NOTE: the model is intentionally not compiled here; DQNAgent.compile()
# compiles it (and the target network) with the optimizer passed below.

# Epsilon decays linearly from 1.0 to 0.1 over 100000 steps during training;
# 0.05 is used while testing.
policy = LinearAnnealedPolicy(
    EpsGreedyQPolicy(),
    attr='eps',
    value_max=1.,
    value_min=.1,
    value_test=.05,
    nb_steps=100000,
)
memory = SequentialMemory(limit=50000, window_length=1)
dqn = DQNAgent(
    model=model,
    nb_actions=nb_actions,
    memory=memory,
    nb_steps_warmup=10000,
    target_model_update=1e-2,
    policy=policy,
)
dqn.compile(Adam(learning_rate=1e-3), metrics=['mae'])

try:
    dqn.fit(env, nb_steps=100000, visualize=True, verbose=2)
    dqn.test(env, nb_episodes=5, visualize=True)
finally:
    # Always release the emulator and its render window, even if training
    # is interrupted or raises.
    env.close()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement