Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import java.util.HashMap;
- import java.util.Random;
- import java.io.*;
- public class QLearningController extends Controller {
- TestPairs pairs = new TestPairs();
- double sumReward = 0.0;
- int nrTicks = 0;
- int nrWrites = 0;
- public SpringObject object;
- ComposedSpringObject cso;
- DoubleFeature x;
- DoubleFeature y;
- DoubleFeature vx;
- DoubleFeature vy;
- DoubleFeature angle;
- RocketEngine left;
- RocketEngine middle;
- RocketEngine right;
- boolean paused = false;
- String laststate, lastaction="0";
- double lastreward;
- HashMap<String, Double> Q = new HashMap<String, Double>();
- HashMap<String, Integer> Nsa = new HashMap<String, Integer>();
- final double GAMMA = 0.95;
- final double EPSILON = 0.1;
- final int MAX_PRIMITIVE_STEPS = 60;
- String currentstate;
- double currentReward;
- int actioncount=1;
- int tickcount=0;
- double rewardsum=0;
- public void init() {
- cso = (ComposedSpringObject) object;
- x = (DoubleFeature) cso.getObjectById("x");
- y = (DoubleFeature) cso.getObjectById("y");
- vx = (DoubleFeature) cso.getObjectById("vx");
- vy = (DoubleFeature) cso.getObjectById("vy");
- angle = (DoubleFeature) cso.getObjectById("angle");
- left = (RocketEngine) cso.getObjectById("rocket_engine_left");
- right = (RocketEngine) cso.getObjectById("rocket_engine_right");
- middle = (RocketEngine) cso.getObjectById("rocket_engine_middle");
- }
- void setNoBurst() {
- left.setBursting(false);
- right.setBursting(false);
- middle.setBursting(false);
- }
- public void tick(int currentTime) {
- if (! paused) {
- currentstate=StateAndReward.getStateSimple(angle.getValue());
- currentReward=StateAndReward.getRewardSimple(vx.getValue(), vy.getValue(), angle.getValue());
- if(currentstate.equals(laststate) && tickcount < MAX_PRIMITIVE_STEPS)
- {
- tickcount++;
- rewardsum+=currentReward;
- }
- else
- {
- if(tickcount>0) currentReward = (double)rewardsum/tickcount;
- rewardsum=0;
- tickcount = 0;
- int nrTicksBeforeStat = 10000; // An example
- if (nrTicks >= nrTicksBeforeStat) {
- TestPair p = new TestPair(nrTicksBeforeStat * nrWrites, (sumReward / nrTicksBeforeStat));
- pairs.addPair(p);
- try {
- writeToFile("output.m", pairs.getMatlabString("steps", "result"));
- } catch (Exception e) {
- e.printStackTrace();
- }
- sumReward = currentReward;
- nrTicks = 0;
- nrWrites++;
- } else {
- nrTicks++;
- sumReward += currentReward;
- }
- if (laststate!=null)
- {
- if(Nsa.containsKey(laststate+lastaction))
- {
- Nsa.put(laststate+lastaction, 1+Nsa.get(laststate+lastaction));
- }
- else Nsa.put(laststate+lastaction, 1);
- double Qsa;
- if(Q.containsKey(laststate+lastaction))
- {
- Qsa = Q.get(laststate+lastaction);
- //System.out.println(Qsa);
- }
- else Qsa=0;
- System.out.println("Qsa: "+Qsa);
- double alpha = (double)1/(1+Nsa.get(laststate+lastaction));
- String bestaction = bestaction(currentstate);
- double Qsaprim;
- if(Q.containsKey(currentstate+bestaction)) Qsaprim = Q.get(currentstate+bestaction);
- else Qsaprim=0;
- double Qsanew = Qsa + alpha*(lastreward - Qsa + GAMMA*Qsaprim);
- System.out.println("Qsanew: "+Qsanew);
- Q.put(laststate+lastaction, Qsanew);
- //System.out.println(Qsa);
- }
- lastreward = currentReward;
- lastaction = explore();
- performaction(lastaction);
- }
- laststate = currentstate;
- actioncount++;
- }
- }
- private void performaction(String actionstring) {
- int action = Integer.valueOf(actionstring);
- if (action == 0){
- this.setNoBurst();
- middle.setBursting(true);
- }
- else if (action == 1){
- this.setNoBurst();
- right.setBursting(true);
- }
- else if (action == 2){
- this.setNoBurst();
- left.setBursting(true);
- }
- else if (action == 3){
- this.setNoBurst();
- }
- }
- private String explore() {
- if (actioncount<5) return lastaction;
- actioncount=1;
- Random random = new Random();
- if(random.nextDouble()<EPSILON){
- int rand = random.nextInt(4);
- //System.out.println(rand);
- return String.valueOf(rand);
- }
- return bestaction(currentstate);
- }
- private String bestaction(String currentstate) {
- double best_value=0;
- int best_index=0;
- for(int i = 0;i<4;i++)
- {
- if(Q.containsKey(currentstate+String.valueOf(i))){
- //System.out.print(i);
- //System.out.print(" "+Q.get(currentstate+String.valueOf(i))+" ");
- if(Q.get(currentstate+String.valueOf(i))>best_value || best_value==0)
- {
- best_value = Q.get(currentstate+String.valueOf(i));
- best_index=i;
- //System.out.println(best_value);
- }
- }
- }
- //System.out.println(best_index);
- //System.out.print(" "+best_value);
- //System.out.println();
- return String.valueOf(best_index);
- }
- public void pause() {
- paused = true;
- setNoBurst();
- }
- public void run() {
- paused = false;
- }
- public void performCommand(String command) {
- super.performCommand(command);
- }
- public void writeToFile(String filename, String content) {
- try {
- FileOutputStream fos = new FileOutputStream(filename);
- fos.write(content.getBytes());
- } catch (Exception e) {
- e.printStackTrace();
- }
- }
- }
Add Comment
Please, Sign In to add comment