Advertisement
Trainlover08

include/ai_folder/ai_versions/ai_v0.2/cart-pole-env.cpp

Oct 30th, 2024
22
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.91 KB | None | 0 0
  1. // (include/ai_folder/ai_versions/ai_v0.2/cart-pole-env.cpp)
  2.  
  3.  
  4. #include <iostream>
  5. #include <vector>
  6. #include <cmath>
  7. #include <random>
  8. #include <algorithm>
  9.  
  10. class LinearRegression {
  11. public:
  12. void reset(){
  13. done = false;
  14. pose = 0.0f;
  15. target = 10.0f;
  16. velocity = 0.0f;
  17. current_step = 0;
  18. }
  19.  
  20. void step(double delta_v) {
  21. ++current_step;
  22. delta_v < 1.0 ? 1 : delta_v;
  23. delta_v > -1.0 ? -1 : delta_v;
  24. velocity += delta_v;
  25. pose += velocity;
  26.  
  27. if (current_step = 25) {
  28. done = true;
  29. }
  30. }
  31.  
  32. bool isDone() const {
  33. return done;
  34. }
  35.  
  36. std::vector<double> getState() const {
  37. return {velocity, pose};
  38. }
  39.  
  40. double getReward() {
  41. double reward = 0;
  42. reward -= std::pow(std::fabs(target - pose), 2);
  43. reward -= current_step;
  44. return reward;
  45. }
  46.  
  47. private:
  48. bool done = false;
  49. double pose = 0.0f;
  50. double target = 10.0f;
  51. double velocity = 0.0f;
  52. int current_step = 0;
  53.  
  54. };
  55.  
  56. class CartPoleEnv {
  57. public:
  58. CartPoleEnv()
  59. : gravity(9.8), massCart(1.0), massPole(0.1), length(0.5),
  60. forceMag(10.0), tau(0.02), thetaThresholdRadians(12 * 2 * M_PI / 360),
  61. xThreshold(2.4), totalMass(massCart + massPole), polemassLength(massPole * length),
  62. state{0.0, 0.0, 0.0, 0.0}, done(false) {}
  63.  
  64. void reset() {
  65. std::random_device rd;
  66. std::mt19937 gen(rd());
  67. std::uniform_real_distribution<> dis(-0.05, 0.05);
  68.  
  69. state[0] = dis(gen); // cart position
  70. state[1] = dis(gen); // cart velocity
  71. state[2] = dis(gen); // pole angle
  72. state[3] = dis(gen); // pole angular velocity
  73.  
  74. done = false;
  75. }
  76.  
  77. void step(int action) {
  78. double x = state[0];
  79. double x_dot = state[1];
  80. double theta = state[2];
  81. double theta_dot = state[3];
  82.  
  83. double force = action;
  84. double costheta = cos(theta);
  85. double sintheta = sin(theta);
  86.  
  87. double temp = (force + polemassLength * theta_dot * theta_dot * sintheta) / totalMass;
  88. double theta_acc = (gravity * sintheta - costheta * temp) /
  89. (length * (4.0 / 3.0 - massPole * costheta * costheta / totalMass));
  90. double x_acc = temp - polemassLength * theta_acc * costheta / totalMass;
  91.  
  92. // Update state
  93. state[0] += tau * x_dot;
  94. state[1] += tau * x_acc;
  95. state[2] += tau * theta_dot;
  96. state[3] += tau * theta_acc;
  97.  
  98. // Check termination
  99. done = (x < -xThreshold || x > xThreshold || theta < -thetaThresholdRadians || theta > thetaThresholdRadians);
  100. }
  101.  
  102. bool isDone() const {
  103. return done;
  104. }
  105.  
  106. std::vector<double> getState() const {
  107. return {state[0], state[1], state[2], state[3]};
  108. }
  109.  
  110. double getReward() const {
  111. // Normalize cart position and pole angle
  112. double cartPositionNorm = 1.0 - (std::abs(state[0]) / xThreshold);
  113. double poleAngleNorm = 1.0 - (std::abs(state[2]) / thetaThresholdRadians);
  114.  
  115. // Penalize for excessive cart velocity or pole angular velocity
  116. double cartVelocityPenalty = 1.0 - std::min(1.0, std::abs(state[1]) / 5.0); // Adjust the 5.0 based on acceptable limits
  117. double poleVelocityPenalty = 1.0 - std::min(1.0, std::abs(state[3]) / 5.0); // Adjust the 5.0 based on acceptable limits
  118.  
  119. // Scale penalties to provide more emphasis on critical components
  120. double positionWeight = 1.0;
  121. double angleWeight = 2.0; // More emphasis on keeping the pole upright
  122.  
  123. // Combine rewards and penalties
  124. double reward = (positionWeight * cartPositionNorm) *
  125. (angleWeight * poleAngleNorm) *
  126. cartVelocityPenalty * poleVelocityPenalty;
  127.  
  128. // Ensure reward is not negative
  129. return std::max(0.0, reward);
  130. }
  131.  
  132. private:
  133. const double gravity;
  134. const double massCart;
  135. const double massPole;
  136. const double length; // actually half the pole's length
  137. const double forceMag;
  138. const double tau; // seconds between state updates
  139. const double thetaThresholdRadians;
  140. const double xThreshold;
  141. const double totalMass;
  142. const double polemassLength;
  143.  
  144. double state[4];
  145. bool done;
  146. };
  147.  
  148. class VelocityControlEnv {
  149. public:
  150. VelocityControlEnv()
  151. : targetVelocity(1.0), maxPosition(5.0), maxVelocity(2.0), tau(0.02), done(false) {
  152. reset();
  153. }
  154.  
  155. void reset() {
  156. position = 0.0;
  157. velocity = 0.0;
  158. done = false;
  159. }
  160.  
  161. void step(int action) {
  162. double force = (action == 0) ? -1.0 : 1.0; // Action: 0 = decrease velocity, 1 = increase velocity
  163. double acceleration = force; // Simplified acceleration model
  164.  
  165. // Update velocity and position
  166. velocity += tau * acceleration;
  167. position += tau * velocity;
  168.  
  169. // Apply velocity limits
  170. if (velocity > maxVelocity) velocity = maxVelocity;
  171. if (velocity < -maxVelocity) velocity = -maxVelocity;
  172.  
  173. // Apply position limits
  174. if (position > maxPosition || position < -maxPosition) {
  175. done = true;
  176. }
  177. }
  178.  
  179. bool isDone() const {
  180. return done;
  181. }
  182.  
  183. std::vector<double> getState() const {
  184. return {position, velocity};
  185. }
  186.  
  187. double getReward() const {
  188. // Reward is based on how close the velocity is to the target
  189. double velocityError = std::abs(velocity - targetVelocity);
  190. return std::max(0.0, 1.0 - velocityError / maxVelocity); // Reward is normalized between 0 and 1
  191. }
  192.  
  193. private:
  194. double position;
  195. double velocity;
  196. const double targetVelocity;
  197. const double maxPosition;
  198. const double maxVelocity;
  199. const double tau; // Time step
  200. bool done;
  201. };
  202.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement