/*
 * BotQLearning.java
 *
 * Created on 1 January 2003, 22:23
 */

/**
 *
 * @author  Vanden Berghen Frank
 * @version 1.1
 *
 * Thanks to Jens G. Balchen for the bug fix in updateQ() (originally line 183).
 *
 */
import java.awt.*;
import java.awt.event.*;

/**
 * Applet in which a crawling robot (a rectangular body with a two-joint arm)
 * learns to move forward with one-step Q-learning.
 *
 * <p>The state is the pair (arm position {@code s1}, hand position {@code s2}),
 * discretised into {@code nStateArm} x {@code nStateHand} cells; each of the
 * four actions moves one joint by one cell.  After every move the value
 * returned by {@link #reward} is added to the body position {@code Xpos} and
 * used to update {@code Q[s1][s2][action]}.
 *
 * <p>Threading: learning steps run on a background thread (see {@link #run()}),
 * while buttons and arrow keys act on the AWT event thread.  All mutable
 * learning state (Q table, current state, body position, counters, batch size
 * and thread handle) is accessed while holding {@code lock}, so manual key
 * steps, Q resets and background steps cannot interleave.
 */
public class BotQLearning extends java.applet.Applet implements Runnable, ActionListener, KeyListener
{
    private static final int nStateArm=9,nStateHand=19;
    private static final double degreeArmMin=Math.toRadians(-30),
                                degreeArmMax=Math.toRadians(30), 
                                degreeHandMin=Math.toRadians(-150),
                                degreeHandMax=Math.toRadians(0),

                                heightBody=40, 
                                widthBody=80,
                                lengthArm=60,
                                lengthHand=40;                         
    private static final int ground=200,
                             frame_size_x=800,
                             frame_size_y=300;

    // Learning parameters: changed by the buttons on the AWT event thread and
    // read by the learning thread, hence volatile.
    private volatile double epsilon=.8, gamma=0.9, alpha=1.0;
    
    private static final double SQR(double a) { return a*a; } 
    
    private static final int armMinus=0, armPlus=1, handMinus=2, handPlus=3, indefinite=-1, nAction=4;
    // Q value stored for moves that would leave the state grid, so they are never the maximum.
    private static final double BADValue=Double.NEGATIVE_INFINITY;

    /** Guards the mutable learning state shared by the AWT thread and the learning thread. */
    private final Object lock=new Object();

    java.util.Random generateur=new java.util.Random();
    private double Q[][][];                                  // guarded by lock
    private double Xpos=100;                                 // guarded by lock
    private int s1Cur=nStateArm/2, s2Cur=nStateHand/2;       // guarded by lock

    private double totalRewards=0.0;                         // guarded by lock
    private int totalSteps=0, runBatch=0;                    // guarded by lock
    volatile boolean running=true;                           // written under lock
    Button run,skipHigh,skipLow,stop,reset,epsDec,epsInc,gamDec,gamInc,alphaDec,alphaInc,resetQ;
    TextField tfEps,tfGamma,tfAlpha;
    Thread myThread=null;                                    // guarded by lock
    java.text.DecimalFormat doubleFormatter;

    public BotQLearning() {}
    
    /**
     * Clears the Q table: every legal (state, action) pair gets 0, and every
     * move that would leave the state grid gets {@link #BADValue}.
     */
    public void resetQ()
    {
        synchronized (lock)
        {
            for (int i=0; i<nStateArm; i++)
                for (int j=0; j<nStateHand; j++) 
                    for (int k=0; k<nAction; k++)
                        Q[i][j][k]=isLegal(i,j,k) ? 0 : BADValue;
        }
    } 

    /** Allocates and resets the Q table, then builds the button/text-field UI. */
    public void init()
    {

        Q=new double[nStateArm][nStateHand][nAction];
        resetQ();
    
        Color c = new Color(123,200,145);
        setBackground(c);
        setLayout(new BorderLayout());
    
        Panel p =new Panel();                          add("North",p);
        run     =new Button("Run");                  p.add(run);
        skipHigh=new Button("Skip 1000000 step");    p.add(skipHigh);
        stop    =new Button("Stop");                 p.add(stop);
        skipLow =new Button("Skip 30000 steps");     p.add(skipLow);
        reset   =new Button ("Reset speed counter"); p.add(reset);
                                                     p.add(new Label("      "));
        resetQ  =new Button ("Reset Q");             p.add(resetQ);

        Panel r =new Panel();                          add("South",r);
        epsDec  =new Button("eps--");                r.add(epsDec);
        tfEps   =new TextField(""+epsilon);          r.add(tfEps);
        epsInc  =new Button("eps++");                r.add(epsInc);
                                                     r.add(new Label("          "));
        gamDec  =new Button("gam--");                r.add(gamDec);
        tfGamma =new TextField(""+gamma);            r.add(tfGamma);
        gamInc  =new Button("gam++");                r.add(gamInc);
                                                     r.add(new Label("          "));
        alphaDec=new Button("alpha--");              r.add(alphaDec);
        tfAlpha =new TextField(""+alpha);            r.add(tfAlpha);
        alphaInc=new Button("alpha++");              r.add(alphaInc);
        
                                                    addKeyListener(this);
             run.addActionListener(this);       run.addKeyListener(this);
        skipHigh.addActionListener(this);  skipHigh.addKeyListener(this);
         skipLow.addActionListener(this);   skipLow.addKeyListener(this);
            stop.addActionListener(this);      stop.addKeyListener(this);
           reset.addActionListener(this);     reset.addKeyListener(this);
          resetQ.addActionListener(this);    resetQ.addKeyListener(this);
          epsDec.addActionListener(this);    epsDec.addKeyListener(this);
          epsInc.addActionListener(this);    epsInc.addKeyListener(this);
          gamDec.addActionListener(this);    gamDec.addKeyListener(this);
          gamInc.addActionListener(this);    gamInc.addKeyListener(this);
        alphaDec.addActionListener(this);  alphaDec.addKeyListener(this);
        alphaInc.addActionListener(this);  alphaInc.addKeyListener(this);

        tfEps.setEnabled(false);
        tfGamma.setEnabled(false);
        tfAlpha.setEnabled(false);
    
        java.text.DecimalFormatSymbols ds=new java.text.DecimalFormatSymbols();
        ds.setDecimalSeparator('.');
        doubleFormatter=new java.text.DecimalFormat("0.0",ds);  
    }

    /**
     * Returns true if applying {@code move} in state (s1, s2) keeps the state
     * inside the {@code nStateArm} x {@code nStateHand} grid.
     */
    private static boolean isLegal(int s1, int s2, int move)
    {
        switch (move)
        {
            case armMinus:  return s1!=0;
            case armPlus:   return s1!=nStateArm-1;
            case handMinus: return s2!=0;
            case handPlus:  return s2!=nStateHand-1;
            default:        return false;
        }
    }
    
    /**
     * Epsilon-greedy action selection among the legal moves of state (s1, s2).
     * With probability {@code 1-epsilon} the move with the highest Q value is
     * returned (first one wins on ties); otherwise a uniformly random legal move.
     */
    private int chooseMovement(int s1, int s2)
    {
        int possibleChoice[]=new int[nAction];
        int nChoice=0;
        // Candidates are collected in action order: armMinus, armPlus, handMinus, handPlus.
        for (int a=0; a<nAction; a++)
            if (isLegal(s1,s2,a)) possibleChoice[nChoice++]=a;
        
        if (generateur.nextDouble()>= epsilon) 
        {
            // choose the best Direction
            double QBest=Q[s1][s2][possibleChoice[0]],v;
            int nBest=0,i;

            for (i=1; i<nChoice; i++) 
            {
                v=Q[s1][s2][possibleChoice[i]];
                if (v>QBest) { nBest=i; QBest=v; }
            }
            return possibleChoice[nBest];
        }
        // choose a random Direction
        return possibleChoice[generateur.nextInt(nChoice)];
    }
    
    /** Returns the largest Q value over all actions of state (s1, s2). */
    private double getMaxQ(int s1,int s2)
    {
        int i;
        double q=Q[s1][s2][0];
        for (i=1; i<nAction; i++) q=Math.max(q,Q[s1][s2][i]);
        return q;
    }
    
    /**
     * Performs one Q-learning step: applies {@code move} (or an epsilon-greedy
     * choice when {@code move==indefinite}), moves the body by the reward and
     * updates {@code Q[s1Old][s2Old][move]}.  The caller must pass a legal move.
     *
     * @param move one of armMinus/armPlus/handMinus/handPlus, or indefinite
     * @return the reward obtained by the move
     */
    double updateQ(int move)
    {
        synchronized (lock)
        {
            int s1Old=s1Cur, s2Old=s2Cur;
            double oldStateValue, aReward;
            
            if (move==indefinite) move=chooseMovement(s1Cur,s2Cur);
            
            switch (move)
            {
                case armMinus:  s1Cur--; break;
                case armPlus:   s1Cur++; break;
                case handMinus: s2Cur--; break;
                case handPlus:  s2Cur++; break;
            }
//
// Inside the paper about Q-learning, we are using a special notation.
// The correspondence between this notation and the variables of the
// program is the following:
//           y = (s1cur,s2cur)
//           x = (s1Old,s2old)
//           a = move
//    Q^*(x,a) = Q[s1Old][s2Old][move]
//      V^*(y) = getMaxQ(s1Cur,s2Cur)
//           r = aReward
//
            oldStateValue=Q[s1Old][s2Old][move];
            aReward=reward(s1Old,s2Old,s1Cur,s2Cur);
            Xpos+=aReward;

            Q[s1Old][s2Old][move]=oldStateValue+alpha*(aReward
                                                       +gamma*getMaxQ(s1Cur,s2Cur)
                                                       -oldStateValue);
            return aReward;
        }
    }

    /** Absolute angle (radians) of the arm for arm state {@code s1}. */
    private static double armAngle(int s1)
    {
        return s1*((degreeArmMax-degreeArmMin)/nStateArm)+degreeArmMin;
    }

    /** Absolute angle (radians) of the hand for hand state {@code s2}, given the arm angle. */
    private static double handAngle(double armAngle, int s2)
    {
        return s2*((degreeHandMax-degreeHandMin)/nStateHand)+degreeHandMin+armAngle;
    }

    /**
     * Forward kinematics: returns {x, y} of the hand tip for state (s1, s2),
     * measured from the body corner that paint() draws at (Xpos, ground), with
     * y pointing up.  y&lt;0 means the tip is below the ground line.
     */
    private static double[] handTip(int s1, int s2)
    {
        double a1=armAngle(s1), a2=handAngle(a1,s2);
        return new double[] {
            lengthArm*Math.cos(a1)+lengthHand*Math.cos(a2)+widthBody,
            lengthArm*Math.sin(a1)+lengthHand*Math.sin(a2)+heightBody
        };
    }

    /**
     * Horizontal body displacement produced when the hand tip moves from
     * (xOld, yOld) to (x, y), both in the frame of {@link #handTip}.
     * <ul>
     * <li>both positions on/above ground: 0 (the hand is in the air);</li>
     * <li>both below ground: old distance of the tip from the origin minus the new one;</li>
     * <li>tip enters the ground: x where the segment crosses y=0, minus the new distance;</li>
     * <li>tip leaves the ground: old distance minus x where the segment crosses y=0.</li>
     * </ul>
     * The divisions never divide by zero: in both crossing cases the two y
     * values lie on opposite sides of 0.
     */
    private static double groundReward(double xOld, double yOld, double x, double y)
    {
        if (y<0)
        {
            // y<0 ; yOld<0
            if (yOld<=0) return Math.sqrt(SQR(xOld)+SQR(yOld))-Math.sqrt(SQR(x)+SQR(y));
            
            // y<0 ; yOld>0
            return (xOld-yOld*(x-xOld)/(y-yOld))-Math.sqrt(SQR(x)+SQR(y));
        }
        // y>0 ; yOld>0
        if (yOld>=0) return 0;

        // y>0 ; yOld<0:
        return -(x-y*(xOld-x)/(yOld-y))+Math.sqrt(SQR(xOld)+SQR(yOld));
    }
    
    /**
     * Reward of the transition from state (s1Old, s2Old) to state (s1, s2):
     * the horizontal distance the body is pushed by the hand (see groundReward).
     */
    double reward(int s1Old,int s2Old,int s1,int s2)
    {
        double[] tipOld=handTip(s1Old,s2Old), tip=handTip(s1,s2);
        return groundReward(tipOld[0],tipOld[1],tip[0],tip[1]);
    }
    
    /**
     * Draws the robot, the ground and the average speed.  The body is tilted
     * when the hand tip is below the ground line.  Also wraps the body back
     * to x=100 when it leaves the frame.
     */
    public void paint (Graphics g) 
    {
        super.paint(g);
        int s1, s2, steps;
        double xBody, rewards;
        // Snapshot the shared state; drawing itself happens outside the lock.
        synchronized (lock)
        {
            s1=s1Cur; s2=s2Cur; xBody=Xpos; steps=totalSteps; rewards=totalRewards;
        }

        // Same kinematics as reward(), so the drawn pose matches the learned one.
        double degreeS1=armAngle(s1),
               degreeS2=handAngle(degreeS1,s2);
        double[] tip=handTip(s1,s2);
        double x=tip[0], y=tip[1], degreeRot=0.0;
       
        if (y<0) degreeRot=Math.atan(-y/x);
        
        double cosRot=Math.cos(degreeRot), sinRot=Math.sin(degreeRot);

        int x1=(int)(xBody), 
            y1=(int)(ground),
            x2=(int)(xBody+cosRot*widthBody),
            y2=(int)(ground-sinRot*widthBody);
        double x3=xBody-sinRot*heightBody,
               y3=ground-cosRot*heightBody,
               x4=x3+cosRot*widthBody,
               y4=y3-sinRot*widthBody,
               xArm=x4+lengthArm*Math.cos(degreeS1+degreeRot),
               yArm=y4-lengthArm*Math.sin(degreeS1+degreeRot);
        int ix4=(int)x4, iy4=(int)y4, 
            ixArm=(int)xArm, iyArm=(int)yArm, 
            ixHand=(int)(xArm+lengthHand*Math.cos(degreeS2+degreeRot)), 
            iyHand=(int)(yArm-lengthHand*Math.sin(degreeS2+degreeRot));
        
        g.setColor(new Color(100,120,12));
        g.drawLine(x1,y1,x2,y2);
        g.drawLine(x1,y1,(int)x3,(int)y3);
        g.drawLine(x2,y2,ix4,iy4);
        g.drawLine((int)x3,(int)y3,ix4,iy4);

        g.setColor(Color.orange);
        g.drawLine(ix4, iy4, ixArm, iyArm);
        g.drawLine(ix4+1, iy4+1, ixArm+1, iyArm+1);
        g.drawLine(ix4-1, iy4-1, ixArm-1, iyArm-1);
        g.drawLine(ix4+1, iy4-1, ixArm+1, iyArm-1);
        g.drawLine(ix4-1, iy4+1, ixArm-1, iyArm+1);

        g.setColor(Color.red);
        g.drawLine(ixArm, iyArm, ixHand, iyHand);
        g.drawLine(ixArm+1, iyArm+1, ixHand, iyHand);
        g.drawLine(ixArm-1, iyArm-1, ixHand, iyHand);
        g.drawLine(ixArm+1, iyArm-1, ixHand, iyHand);
        g.drawLine(ixArm-1, iyArm+1, ixHand, iyHand);

        g.setColor(Color.white);
        g.drawLine(0,ground+1,frame_size_x,ground+1);
        g.drawLine(0,ground+2,frame_size_x,ground+2);
        g.setColor(Color.lightGray);
        g.drawLine(0,ground+3,frame_size_x,ground+3);
        g.drawLine(0,ground+4,frame_size_x,ground+4);
        g.setColor(Color.gray);
        g.drawLine(0,ground+5,frame_size_x,ground+5);
        g.drawLine(0,ground+6,frame_size_x,ground+6);
        g.setColor(Color.darkGray);
        g.drawLine(0,ground+7,frame_size_x,ground+7);
        g.drawLine(0,ground+8,frame_size_x,ground+8);    
        
        if (steps==0)
            g.drawString ("average speed : 0", 40, 50);
        else
            g.drawString ("average speed : "+(rewards/(double)steps), 40, 50);

        synchronized (lock)
        {
            if ((ixArm>frame_size_x)||(ixHand>frame_size_x)) Xpos=100;
            if (Xpos<0) Xpos=100;
        }
    }

    /**
     * Marks learning as running and starts the learning thread if none is
     * alive.  It runs under the lock, so it cannot race with run() deciding
     * to exit.
     */
    private void startMyThread() 
    { 
        synchronized (lock)
        {
            running=true;
            if (myThread==null) { myThread=new Thread(this); myThread.start(); }
        }
    }

    /** Starts (or continues) learning with {@code steps} fast steps before normal animation resumes. */
    private void startBatch(int steps)
    {
        synchronized (lock) { runBatch=steps; startMyThread(); }
    }

    /** Starts animated (non-batch) learning. */
    public void start() { startBatch(0); }

    /** Asks the learning thread to stop after its current step. */
    public void stop()  { synchronized (lock) { running=false; } }

    /** Dispatches button presses; every press ends with a repaint. */
    public void actionPerformed(java.awt.event.ActionEvent e)
    {
        Object s=e.getSource();
        if  (s==run)      start();
        if  (s==stop)     stop();
        if  (s==resetQ)   resetQ();
        if  (s==skipHigh) startBatch(1000000);
        if  (s==skipLow)  startBatch(30000);
        if  (s==reset)    { synchronized (lock) { totalRewards=0; totalSteps=0; } }
        if ((s==epsDec  )&&(epsilon>0.01)) { epsilon-=0.1;   tfEps.setText(doubleFormatter.format(epsilon)); }
        if ((s==gamDec  )&&(gamma  >0.01)) { gamma  -=0.1; tfGamma.setText(doubleFormatter.format(gamma)); }
        if ((s==alphaDec)&&(alpha  >0.01)) { alpha  -=0.1; tfAlpha.setText(doubleFormatter.format(alpha)); }
        if ((s==epsInc  )&&(epsilon<0.99)) { epsilon+=0.1;   tfEps.setText(doubleFormatter.format(epsilon)); }
        if ((s==gamInc  )&&(gamma  <0.99)) { gamma  +=0.1; tfGamma.setText(doubleFormatter.format(gamma)); }
        if ((s==alphaInc)&&(alpha  <0.99)) { alpha  +=0.1; tfAlpha.setText(doubleFormatter.format(alpha)); }
        repaint();
    }
    
    public void keyTyped(java.awt.event.KeyEvent e) {}
    public void keyReleased(java.awt.event.KeyEvent e) {}

    /**
     * Any key press stops automatic learning.  The arrow keys then perform one
     * manual learning step (right/left move the hand, up/down move the arm)
     * if that move is legal.
     */
    public void keyPressed(java.awt.event.KeyEvent e) 
    {
        stop();
        int move;
        switch (e.getKeyCode()) 
        {
            case KeyEvent.VK_RIGHT: move=handPlus;   break;
            case KeyEvent.VK_LEFT:  move=handMinus;  break;
            case KeyEvent.VK_UP:    move=armPlus;    break;
            case KeyEvent.VK_DOWN:  move=armMinus;   break;
            default:                move=indefinite; break;
        }
        if (move!=indefinite)
        {
            // Check and step atomically so the learning thread cannot move the
            // state between the legality test and the update.
            synchronized (lock)
            {
                if (isLegal(s1Cur,s2Cur,move)) { totalRewards+=updateQ(move); totalSteps++; }
            }
        }
        repaint ();         
    }

    /**
     * Draws the remaining batch count on the applet.  This does nothing if
     * the component has no graphics context yet.
     */
    private void drawBatchProgress(Color background, int remaining)
    {
        Graphics g=getGraphics();
        if (g==null) return;
        try
        {
            g.setColor(background);
            g.fillRect(550,40,50,10);
            g.setColor(Color.darkGray);
            g.drawString ("iteration before completion: "+remaining, 400, 50);
        }
        finally
        {
            g.dispose();   // getGraphics() hands out a new context that must be released
        }
    }
    
    /**
     * Body of the learning thread.  It performs one Q-learning step per
     * iteration until {@code running} is cleared.  While a batch is pending
     * ({@code runBatch>0}), steps run back-to-back and progress is drawn every
     * 100 steps.  Otherwise the robot is repainted and the thread sleeps
     * 100 ms between steps.  An interrupt stops learning.
     */
    public void run()
    {
        Color c = new Color(123,200,145);
        while (true)
        {
            int remaining;
            synchronized (lock)
            {
                // Testing 'running' and clearing 'myThread' atomically means a
                // concurrent startMyThread() either keeps this thread going or
                // starts a fresh one; a restart can never be lost.
                if (!running) { myThread=null; return; }
                totalRewards+=updateQ(indefinite); totalSteps++;
                remaining=runBatch;
                if (runBatch>0) runBatch--;
            }
            if (remaining>0) 
            {
                if (remaining%100==0) drawBatchProgress(c,remaining);
            }
            else 
            {
                repaint();
                try { Thread.sleep(100); }
                catch (InterruptedException e)
                {
                    Thread.currentThread().interrupt();   // preserve the interrupt status
                    stop();                               // and treat it as a stop request
                }
            }
        }
    }
    
    /** Runs the applet inside a plain AWT frame. */
    public static void main(String args[])
    {
        final BotQLearning app= new BotQLearning(); 
        Frame aFrame= new Frame("Applet");
        aFrame.addWindowListener( new WindowAdapter() {
            public void windowClosing(WindowEvent e) { app.stop(); app.destroy(); System.exit(0);}
        });
        aFrame.add(app, BorderLayout.CENTER);
        aFrame.setSize(BotQLearning.frame_size_x,BotQLearning.frame_size_y);
        app.init();
        app.start();
        aFrame.setVisible(true);
    }
}
