//
// Created by LasseKB on 07/11/2019.
//

// Based on https://radicalrafi.github.io/posts/pytorch-cpp-intro/, https://github.com/goldsborough/examples/blob/cpp/cpp/mnist/mnist.cpp,
// and https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html

#include "mai_dqntrainer.hpp"
#include "modes/standard_race.hpp" // This should probably not be here.
#include <cmath>    // std::exp
#include <cstdlib>  // rand, srand
#include <ctime>    // time
#include <iostream>
#include <torch/csrc/jit/import.h>
#include <torch/csrc/jit/export.h>

#define BATCH_SIZE 128
#define GAMMA 0.999
#define EPS_START 0.9
#define EPS_END 0.05
#define EPS_DECAY 200
#define TARGET_UPDATE 10
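
// Notes on the hyper-parameters above:
// - BATCH_SIZE:     number of transitions sampled from replay memory per optimisation step.
// - GAMMA:          discount factor applied to the target network's value of the next state.
// - EPS_START/END/DECAY: parameters of the epsilon-greedy schedule used in selectAction(),
//   eps = EPS_END + (EPS_START - EPS_END) * exp(-stepsDone / EPS_DECAY),
//   e.g. eps ~= 0.90 at step 0, ~= 0.36 after 200 steps, approaching 0.05 in the limit.
// - TARGET_UPDATE:  number of episodes between copies of the policy net into the target net.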

MAIDQNTrainer::MAIDQNTrainer(MAIDQNModel *model) {
    m_policyNet = model;
    m_targetNet = new MAIDQNModel(model->getKartID());

    // Make targetNet's weights the same as policyNet's by saving the policy module to a
    // temporary archive on disk and loading it back into the target module. The target net
    // is only updated this way (here and every TARGET_UPDATE episodes in run()), which keeps
    // the Q-learning targets stable between syncs.
    if (!m_policyNet->getModule()->is_serializable()) throw 909;
    //torch::save(m_policyNet->getModule(), "tempPolicyToTarget.pt");
    auto compilationUnit = std::make_shared<torch::jit::script::CompilationUnit>();
    torch::serialize::OutputArchive outArchive(compilationUnit);
    //(dynamic_cast<torch::nn::Module*>(m_policyNet))->save(outArchive);
    m_policyNet->getModule()->save(outArchive);
    outArchive.save_to("tempPolicyToTarget.pt");
    torch::serialize::InputArchive inArchive;
    //torch::load(*(m_targetNet->getModule()), "tempPolicyToTarget.pt");
    inArchive.load_from("tempPolicyToTarget.pt");
    m_targetNet->getModule()->load(inArchive);

    m_stepsDone = 0;
    srand(time(NULL));
    // Only the policy net is trained; RMSprop with a learning rate of 0.01.
    m_optimiser = new torch::optim::RMSprop(m_policyNet->getModule()->parameters(), torch::optim::RMSpropOptions(0.01));
}

PlayerAction MAIDQNTrainer::selectAction(float state) {
    // Epsilon-greedy action selection: explore with probability eps_threshold,
    // otherwise exploit the policy network. The threshold decays from EPS_START
    // towards EPS_END as more steps are taken.
    float sample = (rand() % 100) / 100.0f;
    float eps_threshold = EPS_END + (EPS_START - EPS_END) * std::exp(-1. * m_stepsDone / EPS_DECAY);
    m_stepsDone += 1;
    if (sample > eps_threshold)
        return m_policyNet->getAction(state);
    else
    {
        // Pick a uniformly random action index.
        int ind = rand() % m_policyNet->getNumActions();
        return m_policyNet->getAction(ind);
    }
}

void MAIDQNTrainer::optimiseModel() {
    // Wait until the replay memory holds at least one full batch.
    if (replayMemory.states.size() < BATCH_SIZE) return;

    torch::Tensor stateTensor /*= torch::zeros({BATCH_SIZE, 1})*/;
    torch::Tensor actionTensor;
    torch::Tensor rewardTensor;
    torch::Tensor nextStateTensor;

    // Sample BATCH_SIZE transitions (with replacement) from the replay memory and
    // concatenate them into flat tensors, one entry per sampled transition.
    for (int i = 0; i < BATCH_SIZE; i++) {
        int sample = rand() % replayMemory.states.size();

        // Earlier per-sample version of the update, left commented out:
        //torch::Tensor stateActionValues = m_policyNet->pseudoForward(replayMemory.states[sample]).max();
        //torch::Tensor nextStateValues = torch::zeros(1);
        //nextStateValues[0] = m_targetNet->pseudoForward(replayMemory.nextStates[sample]).max();
        //torch::Tensor expectedStateActionValues = (nextStateValues * GAMMA) + replayMemory.rewards[sample];
        //auto loss = torch::smooth_l1_loss(stateActionValues, expectedStateActionValues.detach());
        //m_optimiser->zero_grad();
        //loss.backward();
        //m_optimiser->step();

        if (i == 0) {
            stateTensor = torch::tensor(replayMemory.states[sample]);
            actionTensor = torch::tensor(replayMemory.actions[sample]);
            rewardTensor = torch::tensor(replayMemory.rewards[sample]);
            nextStateTensor = torch::tensor(replayMemory.nextStates[sample]);
            continue;
        }

        stateTensor = torch::cat({ stateTensor, torch::tensor(replayMemory.states[sample]) }, 0);
        actionTensor = torch::cat({ actionTensor, torch::tensor(replayMemory.actions[sample]) }, 0);
        rewardTensor = torch::cat({ rewardTensor, torch::tensor(replayMemory.rewards[sample]) }, 0);
        nextStateTensor = torch::cat({ nextStateTensor, torch::tensor(replayMemory.nextStates[sample]) }, 0);
    }
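
    // The update below follows the standard DQN loss:
    //   Q_policy(s, a)  vs.  target = r + GAMMA * max_a' Q_target(s', a')
    // i.e. the policy net's value of the action actually taken is regressed, with a
    // smooth L1 (Huber) loss, towards the reward plus the discounted best value the
    // (frozen) target net assigns to the next state. The target term is detached so
    // no gradients flow into the target network.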
    // Reshape the flat batches into [BATCH_SIZE, 1] column tensors; gather() needs the
    // action indices as 64-bit integers.
    stateTensor = stateTensor.reshape({ BATCH_SIZE, 1 });
    actionTensor = actionTensor.to(torch::kLong);
    actionTensor = actionTensor.reshape({ BATCH_SIZE, 1 });
    nextStateTensor = nextStateTensor.reshape({ BATCH_SIZE, 1 });

    // Q(s, a) for the actions that were actually taken, from the policy net.
    torch::Tensor stateActionValues = m_policyNet->forward(stateTensor).gather(1, actionTensor);
    // max_a' Q_target(s', a') from the target net, detached from the graph.
    torch::Tensor nextStateValues = std::get<0>(m_targetNet->forward(nextStateTensor).max(1)).detach();
    torch::Tensor expectedStateValues = (nextStateValues * GAMMA) + rewardTensor;

    // Huber loss between predicted and expected state-action values.
    auto loss = torch::smooth_l1_loss(stateActionValues, expectedStateValues.unsqueeze(1));

    m_optimiser->zero_grad();
    loss.backward();
    m_optimiser->step();
}

void MAIDQNTrainer::run() {
    // Train for 50 episodes.
    for (int i = 0; i < 50; i++) {
        // TODO reset world
        World *world = World::getWorld();
        StandardRace* srWorld = dynamic_cast<StandardRace*>(world);
        float state = srWorld->getDistanceDownTrackForKart(m_policyNet->getKartID(), /*Account for checklines? WTH is this?*/false);
        float startState = state;
        bool raceDone = false;
        while (!raceDone) {
            PlayerAction action = selectAction(state);
            // TODO perform the action
            float nextState = state + 1.0f; // Placeholder
            float reward = 1.0f; // Placeholder
            bool done = (state > startState + 50.0f); // Placeholder // Floating point comparison might be a problem :(

            // Store the transition in replay memory, then take one optimisation step.
            replayMemory.states.push_back(state);
            replayMemory.actions.push_back(action);
            replayMemory.nextStates.push_back(nextState);
            replayMemory.rewards.push_back(reward);

            state = nextState;

            optimiseModel();
            if (done) break;
        }
        // Every TARGET_UPDATE episodes, make targetNet's weights the same as policyNet's,
        // using the same archive round-trip as in the constructor.
        if (i % TARGET_UPDATE == 0) {
            if (!m_policyNet->getModule()->is_serializable()) throw 909;
            auto compilationUnit = std::make_shared<torch::jit::script::CompilationUnit>();
            torch::serialize::OutputArchive outArchive(compilationUnit);
            m_policyNet->getModule()->save(outArchive);
            outArchive.save_to("tempPolicyToTarget.pt");
            torch::serialize::InputArchive inArchive;
            inArchive.load_from("tempPolicyToTarget.pt");
            m_targetNet->getModule()->load(inArchive);
        }
    }
    std::cout << "Running is fun :D\n";
}
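
// Rough usage sketch (hypothetical; the model construction and kartId value live in the
// surrounding game code, not in this file):
//
//     MAIDQNModel *model = new MAIDQNModel(kartId);
//     MAIDQNTrainer trainer(model);
//     trainer.run();   // runs 50 training episodes against the current World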