Advertisement
Guest User

Untitled

a guest
Jul 25th, 2017
69
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.75 KB | None | 0 0
  1. """
  2. __name__ = predict.py
  3. __author__ = Yash Patel
  4. __description__ = Does the prediction using the defined model and data
  5. """
  6.  
  7. import gym
  8. import numpy as np
  9.  
  10. from data import gather_data
  11. from model import create_model
  12.  
  13. def predict():
  14. env = gym.make("CartPole-v0")
  15. trainingX, trainingY = gather_data(env)
  16. model = create_model()
  17. model.fit(trainingX, trainingY, epochs=5)
  18.  
  19. scores = []
  20. num_trials = 50
  21. sim_steps = 500
  22.  
  23. for trial in range(num_trials):
  24. observation = env.reset()
  25. score = 0
  26. for step in range(sim_steps):
  27. action = np.argmax(model.predict(observation.reshape(1,4)))
  28. observation, reward, done, _ = env.step(action)
  29. score += reward
  30. if done:
  31. break
  32. scores.append(score)
  33.  
  34. print(np.mean(scores))
  35.  
  36. if __name__ == "__main__":
  37. predict()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement