Guest User

Untitled

a guest
Jan 22nd, 2018
93
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.24 KB | None | 0 0
  1. import java.util.HashMap;
  2. import java.util.Random;
  3. import java.io.*;
  4.  
  5. public class QLearningController extends Controller {
  6.  
  7. TestPairs pairs = new TestPairs();
  8. double sumReward = 0.0;
  9. int nrTicks = 0;
  10. int nrWrites = 0;
  11.  
  12. public SpringObject object;
  13.  
  14. ComposedSpringObject cso;
  15. DoubleFeature x;
  16. DoubleFeature y;
  17. DoubleFeature vx;
  18. DoubleFeature vy;
  19. DoubleFeature angle;
  20.  
  21. RocketEngine left;
  22. RocketEngine middle;
  23. RocketEngine right;
  24.  
  25. boolean paused = false;
  26. String laststate, lastaction="0";
  27. double lastreward;
  28. HashMap<String, Double> Q = new HashMap<String, Double>();
  29. HashMap<String, Integer> Nsa = new HashMap<String, Integer>();
  30. final double GAMMA = 0.95;
  31. final double EPSILON = 0.1;
  32. final int MAX_PRIMITIVE_STEPS = 60;
  33. String currentstate;
  34. double currentReward;
  35. int actioncount=1;
  36. int tickcount=0;
  37. double rewardsum=0;
  38.  
  39. public void init() {
  40. cso = (ComposedSpringObject) object;
  41. x = (DoubleFeature) cso.getObjectById("x");
  42. y = (DoubleFeature) cso.getObjectById("y");
  43. vx = (DoubleFeature) cso.getObjectById("vx");
  44. vy = (DoubleFeature) cso.getObjectById("vy");
  45. angle = (DoubleFeature) cso.getObjectById("angle");
  46.  
  47. left = (RocketEngine) cso.getObjectById("rocket_engine_left");
  48. right = (RocketEngine) cso.getObjectById("rocket_engine_right");
  49. middle = (RocketEngine) cso.getObjectById("rocket_engine_middle");
  50.  
  51. }
  52.  
  53. void setNoBurst() {
  54. left.setBursting(false);
  55. right.setBursting(false);
  56. middle.setBursting(false);
  57. }
  58.  
  59. public void tick(int currentTime) {
  60.  
  61. if (! paused) {
  62.  
  63. currentstate=StateAndReward.getStateSimple(angle.getValue());
  64. currentReward=StateAndReward.getRewardSimple(vx.getValue(), vy.getValue(), angle.getValue());
  65.  
  66. if(currentstate.equals(laststate) && tickcount < MAX_PRIMITIVE_STEPS)
  67. {
  68. tickcount++;
  69. rewardsum+=currentReward;
  70.  
  71. }
  72. else
  73. {
  74. if(tickcount>0) currentReward = (double)rewardsum/tickcount;
  75. rewardsum=0;
  76. tickcount = 0;
  77. int nrTicksBeforeStat = 10000; // An example
  78.  
  79. if (nrTicks >= nrTicksBeforeStat) {
  80. TestPair p = new TestPair(nrTicksBeforeStat * nrWrites, (sumReward / nrTicksBeforeStat));
  81. pairs.addPair(p);
  82. try {
  83. writeToFile("output.m", pairs.getMatlabString("steps", "result"));
  84. } catch (Exception e) {
  85. e.printStackTrace();
  86. }
  87. sumReward = currentReward;
  88. nrTicks = 0;
  89. nrWrites++;
  90. } else {
  91. nrTicks++;
  92. sumReward += currentReward;
  93. }
  94.  
  95.  
  96. if (laststate!=null)
  97. {
  98. if(Nsa.containsKey(laststate+lastaction))
  99. {
  100. Nsa.put(laststate+lastaction, 1+Nsa.get(laststate+lastaction));
  101. }
  102. else Nsa.put(laststate+lastaction, 1);
  103.  
  104. double Qsa;
  105. if(Q.containsKey(laststate+lastaction))
  106. {
  107. Qsa = Q.get(laststate+lastaction);
  108. //System.out.println(Qsa);
  109. }
  110. else Qsa=0;
  111. System.out.println("Qsa: "+Qsa);
  112. double alpha = (double)1/(1+Nsa.get(laststate+lastaction));
  113. String bestaction = bestaction(currentstate);
  114. double Qsaprim;
  115. if(Q.containsKey(currentstate+bestaction)) Qsaprim = Q.get(currentstate+bestaction);
  116. else Qsaprim=0;
  117. double Qsanew = Qsa + alpha*(lastreward - Qsa + GAMMA*Qsaprim);
  118. System.out.println("Qsanew: "+Qsanew);
  119. Q.put(laststate+lastaction, Qsanew);
  120. //System.out.println(Qsa);
  121. }
  122.  
  123. lastreward = currentReward;
  124. lastaction = explore();
  125. performaction(lastaction);
  126. }
  127. laststate = currentstate;
  128. actioncount++;
  129.  
  130.  
  131. }
  132.  
  133. }
  134.  
  135.  
  136. private void performaction(String actionstring) {
  137. int action = Integer.valueOf(actionstring);
  138. if (action == 0){
  139. this.setNoBurst();
  140. middle.setBursting(true);
  141. }
  142. else if (action == 1){
  143. this.setNoBurst();
  144. right.setBursting(true);
  145. }
  146. else if (action == 2){
  147. this.setNoBurst();
  148. left.setBursting(true);
  149. }
  150. else if (action == 3){
  151. this.setNoBurst();
  152. }
  153. }
  154.  
  155. private String explore() {
  156. if (actioncount<5) return lastaction;
  157. actioncount=1;
  158. Random random = new Random();
  159. if(random.nextDouble()<EPSILON){
  160. int rand = random.nextInt(4);
  161. //System.out.println(rand);
  162. return String.valueOf(rand);
  163. }
  164.  
  165. return bestaction(currentstate);
  166. }
  167.  
  168. private String bestaction(String currentstate) {
  169. double best_value=0;
  170. int best_index=0;
  171. for(int i = 0;i<4;i++)
  172. {
  173. if(Q.containsKey(currentstate+String.valueOf(i))){
  174. //System.out.print(i);
  175. //System.out.print(" "+Q.get(currentstate+String.valueOf(i))+" ");
  176. if(Q.get(currentstate+String.valueOf(i))>best_value || best_value==0)
  177. {
  178. best_value = Q.get(currentstate+String.valueOf(i));
  179. best_index=i;
  180. //System.out.println(best_value);
  181. }
  182. }
  183. }
  184. //System.out.println(best_index);
  185. //System.out.print(" "+best_value);
  186. //System.out.println();
  187. return String.valueOf(best_index);
  188.  
  189. }
  190.  
  191. public void pause() {
  192. paused = true;
  193. setNoBurst();
  194. }
  195.  
  196.  
  197. public void run() {
  198. paused = false;
  199. }
  200.  
  201.  
  202. public void performCommand(String command) {
  203. super.performCommand(command);
  204. }
  205.  
  206. public void writeToFile(String filename, String content) {
  207. try {
  208. FileOutputStream fos = new FileOutputStream(filename);
  209. fos.write(content.getBytes());
  210. } catch (Exception e) {
  211. e.printStackTrace();
  212. }
  213. }
  214.  
  215. }
Add Comment
Please, Sign In to add comment