Guest User

testingNeat

a guest
Jul 11th, 2019
48
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.65 KB | None | 0 0
#include <cmath>
#include <cstddef>
#include <iostream>
#include <string>
#include <utility>

#include <mlpack/methods/neat/neat.hpp>
#include "environment.hpp"

using namespace gym;
using namespace mlpack::neat;
  7.  
  8. class GymTask
  9. {
  10. public:
  11. GymTask()
  12. { /* Nothing to do here */ }
  13.  
  14. double Evaluate(Genome<> genome)
  15. {
  16. const std::string environment = "CartPole-v0";
  17. const std::string host = "127.0.0.1";
  18. const std::string port = "4040";
  19.  
  20. double totalReward = 0;
  21. size_t totalSteps = 0;
  22.  
  23. Environment env(host, port, environment);
  24. arma::mat observation = env.reset();
  25.  
  26. while (1)
  27. {
  28. arma::vec input = {observation(0, 0), observation(1, 0), observation(2, 0), observation(3, 0)};
  29. arma::vec output = genome.Evaluate(input);
  30. output = arma::clamp(output, 0, 1);
  31. output[0] = std::round(output[0]);
  32. arma::mat action(output);
  33. env.step(action);
  34.  
  35. observation = env.observation;
  36.  
  37. if (env.done)
  38. break;
  39. totalReward += env.reward;
  40. totalSteps += 1;
  41. }
  42.  
  43. std::cout << "Instance: " << env.instance << " total steps: " << totalSteps
  44. << " reward: " << totalReward << std::endl;
  45.  
  46. return totalReward;
  47. }
  48.  
  49. double EvaluateWithVid(Genome<> genome)
  50. {
  51. const std::string environment = "CartPole-v0";
  52. const std::string host = "127.0.0.1";
  53. const std::string port = "4040";
  54.  
  55. double totalReward = 0;
  56. size_t totalSteps = 0;
  57.  
  58. Environment env(host, port, environment);
  59. env.compression(0);
  60. env.monitor.start("./dummy/", true, true);
  61.  
  62. arma::mat observation = env.reset();
  63. env.render();
  64.  
  65. while (1)
  66. {
  67. arma::vec input = {observation(0, 0), observation(1, 0), observation(2, 0), observation(3, 0)};
  68. arma::vec output = genome.Evaluate(input);
  69. output = arma::clamp(output, 0, 1);
  70. output[0] = std::round(output[0]);
  71. arma::mat action(output);
  72. env.step(action);
  73.  
  74. if (env.done)
  75. break;
  76.  
  77. observation = env.observation;
  78.  
  79. totalReward += env.reward;
  80. totalSteps += 1;
  81. std::cout << "Current step: " << totalSteps << " current reward: "
  82. << totalReward << std::endl;
  83. }
  84.  
  85. std::cout << "Instance: " << env.instance << " total steps: " << totalSteps
  86. << " reward: " << totalReward << std::endl;
  87.  
  88. std::cout << "Video: https://kurg.org/media/gym/" << env.instance
  89. << " (it might take some minutes before the video is accessible)."
  90. << std::endl;
  91.  
  92. return totalReward;
  93. }
  94. };
  95.  
  96. int main(int argc, char* argv[])
  97. {
  98. GymTask task;
  99. NEAT<GymTask> model(task, 4, 1, 100, 20, 10);
  100.  
  101. Genome<> bestGenome = model.Train();
  102. double finalFitness = task.EvaluateWithVid(bestGenome);
  103. std::cout << finalFitness << std::endl;
  104. return 0;
  105. }
Add Comment
Please, Sign In to add comment