Guest User

Untitled

a guest
Oct 21st, 2017
79
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.22 KB | None | 0 0
  1. def train(model,epochs):
  2. # Train
  3. #Reseting the win counter
  4. win_cnt = 0
  5. # We want to keep track of the progress of the AI over time, so we save its win count history
  6. win_hist = []
  7. #Epochs is the number of games we play
  8. for e in range(epochs):
  9. loss = 0.
  10. #Resetting the game
  11. env.reset()
  12. game_over = False
  13. # get initial input
  14. input_t = env.observe()
  15.  
  16. while not game_over:
  17. #The learner is acting on the last observed game screen
  18. #input_t is a vector containing representing the game screen
  19. input_tm1 = input_t
  20.  
  21. #Take a random action with probability epsilon
  22. if np.random.rand() <= epsilon:
  23. #Eat something random from the menu
  24. action = np.random.randint(0, num_actions, size=1)
  25. else:
  26. #Choose yourself
  27. #q contains the expected rewards for the actions
  28. q = model.predict(input_tm1)
  29. #We pick the action with the highest expected reward
  30. action = np.argmax(q[0])
  31.  
  32. # apply action, get rewards and new state
  33. input_t, reward, game_over = env.act(action)
  34. #If we managed to catch the fruit we add 1 to our win counter
  35. if reward == 1:
  36. win_cnt += 1
  37.  
  38. #Uncomment this to render the game here
  39. #display_screen(action,3000,inputs[0])
  40.  
  41. """
  42. The experiences < s, a, r, sā€™ > we make during gameplay are our training data.
  43. Here we first save the last experience, and then load a batch of experiences to train our model
  44. """
  45.  
  46. # store experience
  47. exp_replay.remember([input_tm1, action, reward, input_t], game_over)
  48.  
  49. # Load batch of experiences
  50. inputs, targets = exp_replay.get_batch(model, batch_size=batch_size)
  51.  
  52. # train model on experiences
  53. batch_loss = model.train_on_batch(inputs, targets)
  54.  
  55. #sum up loss over all batches in an epoch
  56. loss += batch_loss
  57. win_hist.append(win_cnt)
  58. return win_hist
Add Comment
Please, Sign In to add comment