#include <stdlib.h>
#include <memory.h>
#include <math.h>
#include "mathutil.h"
#include "Lazy.h"
/*
	Lazy(Table *x, VectorI *range)
		init the example table from a basic table and set the kernel range
	eval(double *Q, double *y_hat, double **t_hat2, double *err);
		eval the output of the lazy algorithm at the query point Q
	setWvec(double *Wvec)
		set the Wvec vector
	unsetWvec()
		unset the Wvec vector
	setValM(long val0,long val1,long val2);
		set the ValM vector
	unsetValM();
		unset the ValM vector
	setComb(int c);
	setComb(VectorI *v);
		set the combination of the different models
	setLAMBDA(double LAMBDA);
		set Lambda
	unsetLAMBDA();
		reset Lambda to its default value
*/

// #define L1METRIC 1

// Build a lazy learner on top of the example table _x.
//   _x     : example table; X->n is read as the number of columns and
//            sizes every work vector allocated here.
//   _range : kernel range, forwarded unchanged to setKernel().
// The work vectors C, v1/t1/a1 and v2/t2/a2 are allocated once here and
// released in the destructor; the buffers that depend on the kernel range
// are managed by checkAndAllocateMemory()/freeStuff().
Lazy::Lazy(Table *_x, VectorI *_range): 
  X(_x), Z1(NULL), Z2(NULL), W(NULL), BestDist(NULL), 
  bestPoints(NULL), BestA(NULL), Best0(NULL), Best1(NULL), 
  Best2(NULL), g(NULL)
{
    int nz,nz2,i;
    double *Vvec1, *Vvec2;
    
    LAMBDA=1e6;
    acbA=1;
    // Per-degree combination counts are only meaningful once setComb() has
    // been called with three values; give them a defined value until then.
    acb0=0;
    acb1=0;
    acb2=0;
    Wvec=NULL;
    val0=0;
    val1=0;
    val2=0;

    // setKernel() compares the requested bounds with the stored ones and
    // only (re)allocates when they differ. Start from a sentinel so that the
    // comparison never reads indeterminate values.
    idm0a=-1; idM0a=-1;
    idm1a=-1; idM1a=-1;
    idm2a=-1; idM2a=-1;

    // No model is enabled until checkAndAllocateMemory() says so; eval()
    // bails out early while all three flags are zero.
    Go0=0;
    Go1=0;
    Go2=0;
    maxNZ=0;

    setKernel(_range);

    nz=X->n;
    nz2=(nz)*(nz+1)/2;

    // work vector for point inside kernel search
    C = (double*)calloc(nz-1,sizeof(double));

    // work vector for linear model (v1 is an nz x nz matrix stored in one
    // contiguous allocation, v1[i] pointing at row i)
    Vvec1 = (double*)calloc(nz*nz,sizeof(double));
    v1 = (double**)calloc(nz,sizeof(double*));
    for (i=0; i<nz; i++, Vvec1+=nz) v1[i] =Vvec1;
    t1  = (double*)calloc(nz,sizeof(double));
    a1  = (double*)calloc(nz,sizeof(double));

    // work vector for quadratic model (same layout with nz2 = nz*(nz+1)/2)
    Vvec2 = (double*)calloc(nz2*nz2,sizeof(double));
    v2 = (double**)calloc(nz2,sizeof(double*));
    for (i=0; i<nz2; i++, Vvec2+=nz2) v2[i] = Vvec2;
    t2 = (double*)calloc(nz2,sizeof(double));
    a2 = (double*)calloc(nz2,sizeof(double));
}

// Release everything the object owns: the kernel-dependent buffers (via
// freeStuff()), the optional weight vector and the fixed-size work vectors
// allocated by the constructor.
Lazy::~Lazy()
{
    freeStuff();

    // free(NULL) is a no-op, so the weight vector needs no guard.
    free(Wvec);

    // Quadratic-model work vectors; row 0 owns the contiguous matrix storage.
    free(a2);
    free(t2);
    free(v2[0]);
    free(v2);

    // Linear-model work vectors, same layout.
    free(a1);
    free(t1);
    free(v1[0]);
    free(v1);

    free(C);
}

// Notify the learner that the example table changed: the size-dependent
// buffers are recomputed and reallocated. Returns the result of
// checkAndAllocateMemory().
BOOL Lazy::tableHasChanged()
{
    const BOOL reallocated = checkAndAllocateMemory();
    return reallocated;
}

// Return the requested kernel bounds as a freshly allocated 6-element
// vector laid out as (min0, max0, min1, max1, min2, max2).
// The vector is created with new; nothing in this function releases it.
VectorI *Lazy::getKernel()
{
    VectorI *bounds=new VectorI(6);
    VectorI &b=*bounds;

    b[0]=idm0a;  b[1]=idM0a;   // constant model
    b[2]=idm1a;  b[3]=idM1a;   // linear model
    b[4]=idm2a;  b[5]=idM2a;   // quadratic model

    return bounds;
}

// Set the requested neighbour-count bounds for the three model degrees.
//   range with 3 elements: one factor per degree; the bounds are derived as
//       min = 3*c*factor, max = 5*c*factor
//     with c = 1, nz and nz*(nz+1)/2 for the constant, linear and quadratic
//     model respectively (nz = X->n).
//   range with 6 or more elements: explicit bounds in the order
//       (min0, max0, min1, max1, min2, max2).
// Returns FALSE when range is NULL or has an unsupported size (nothing is
// modified in that case); otherwise TRUE when the bounds are unchanged, or
// the result of checkAndAllocateMemory() when they changed.
BOOL Lazy::setKernel(VectorI *range)
{
    int _idm0a,_idm1a,_idm2a,_idM0a,_idM1a,_idM2a;
    int nz=X->n,i;

    if (range==NULL) return FALSE;

    if (range->n==3)
    {
        i = nz*(nz+1)/2;
        _idm0a=3*(*range)[0];
        _idm1a=3*nz*(*range)[1];
        _idm2a=3*i*(*range)[2];
        _idM0a=5*(*range)[0];
        _idM1a=5*nz*(*range)[1];
        _idM2a=5*i*(*range)[2];
    } else if (range->n>=6)
    {
        _idm0a=(*range)[0];
        _idM0a=(*range)[1];
        _idm1a=(*range)[2];
        _idM1a=(*range)[3];
        _idm2a=(*range)[4];
        _idM2a=(*range)[5];
    } else
    {
        // Fewer than six explicit bounds: indexing up to [5] would read
        // past the end of the vector.
        return FALSE;
    }

    if ((_idm0a!=idm0a)||(_idM0a!=idM0a)||
        (_idm1a!=idm1a)||(_idM1a!=idM1a)||
        (_idm2a!=idm2a)||(_idM2a!=idM2a))
    {
        idm0a=_idm0a; idm1a=_idm1a; idm2a=_idm2a; 
        idM0a=_idM0a; idM1a=_idM1a; idM2a=_idM2a;
        return checkAndAllocateMemory();
    }
    return TRUE;
}

// Set the number of nearest points used for validation by each model
// degree (read by eval()). Negative requests are clamped to 0, and a stored
// non-zero value below 2 is raised to 2.
void Lazy::setValM(long _val0,long _val1,long _val2)
{
    val0 = MMAX(_val0,0);
    val1 = MMAX(_val1,0);
    val2 = MMAX(_val2,0);

    if ((val0!=0)&&(val0<2)) val0 = 2;
    if ((val1!=0)&&(val1<2)) val1 = 2;
    if ((val2!=0)&&(val2<2)) val2 = 2;
}

// Reset the three validation counts to 0, the value eval() treats as
// "not set".
void Lazy::unsetValM()
{
    val0 = val1 = val2 = 0;
}

// Install a private copy of the per-input distance weights used by eval().
//   _Wvec : array of X->n-1 weights (one per input column). NULL removes
//           the weights, exactly like unsetWvec().
// A new buffer is allocated on every call so that its size always matches
// the current table width, and the copy is taken before the old buffer is
// released so that passing the currently installed vector is safe.
// If there is no input column or the allocation fails, the learner falls
// back to unweighted distances (Wvec == NULL).
void Lazy::setWvec(double *_Wvec)
{
    const long nx=(long)(X->n)-1;
    double *fresh=NULL;

    if ((_Wvec!=NULL)&&(nx>0))
    {
        const size_t bytes=(size_t)nx*sizeof(double);
        fresh=(double*)malloc(bytes);
        if (fresh!=NULL) memcpy(fresh,_Wvec,bytes);
    }

    free(Wvec);   // free(NULL) is a no-op
    Wvec=fresh;
}

// Drop the distance weights; eval() then uses the unweighted distance.
void Lazy::unsetWvec()
{
    free(Wvec);   // harmless when no weights are installed (free(NULL))
    Wvec=NULL;
}

// Choose how many of the best models are combined by eval().
//   v with 1 element : a single count shared by all model degrees (acbA).
//   v with 3 or more elements : one count per degree (constant, linear,
//       quadratic); the shared count acbA is reset to 0.
// Returns FALSE when v is NULL or has 0 or 2 elements (nothing is modified
// in that case); otherwise TRUE when the setting is unchanged, or the
// result of checkAndAllocateMemory() when it changed.
BOOL Lazy::setComb(VectorI* v)
{   
    int _acbA,_acb0,_acb1,_acb2;

    if (v==NULL) return FALSE;

    if (v->n==1)
    {
        if (acbA!=(*v)[0])
        {
            acbA=(*v)[0];
            return checkAndAllocateMemory();
        }
        return TRUE;
    }

    // The per-degree form reads elements 0..2; refuse shorter vectors
    // instead of indexing past their end.
    if (v->n<3) return FALSE;

    _acbA=0;
    _acb0=(*v)[0];
    _acb1=(*v)[1];
    _acb2=(*v)[2];
    if ((_acbA!=acbA)||(_acb0!=acb0)||(_acb1!=acb1)||(_acb2!=acb2))
    {
        acbA=_acbA; acb0=_acb0; acb1=_acb1; acb2=_acb2;
        return checkAndAllocateMemory();
    }
    return TRUE;
}

// Convenience overload: use one shared combination count c for all model
// degrees. Wraps c in a 1-element vector and delegates to
// setComb(VectorI*).
BOOL Lazy::setComb(int c)
{
    VectorI single(1);
    single[0]=c;
    return setComb(&single);
}

// Return the requested combination setting as a freshly allocated vector:
//   1 element  (acbA)             when a shared count is in use,
//   3 elements (acb0, acb1, acb2) when per-degree counts are in use.
// The vector is created with new; nothing in this function releases it.
VectorI *Lazy::getComb()
{
    if (acbA) 
    {
        VectorI *v=new VectorI(1);
        (*v)[0]=acbA;
        return v;
    }
    VectorI *v=new VectorI(3);
    // One slot per degree. Writing all three values to slot 0 would leave
    // slots 1 and 2 unset and report acb2 as the constant-model count.
    (*v)[0]=acb0;
    (*v)[1]=acb1;
    (*v)[2]=acb2;
    return v;
}

// Release every buffer allocated by checkAndAllocateMemory() and reset the
// corresponding members to NULL, so the function can be called repeatedly.
void Lazy::freeStuff()
{
    // Z1 / Z2 are row-pointer tables whose first row owns the contiguous
    // storage of the whole matrix.
    if (Z1!=NULL)
    {
        free(Z1[0]);
        free(Z1);
        Z1=NULL;
    }
    if (Z2!=NULL)
    {
        free(Z2[0]);
        free(Z2);
        Z2=NULL;
    }

    // free(NULL) and delete on a null pointer are no-ops.
    free(W);
    W=NULL;
    free(BestDist);
    BestDist=NULL;

    delete bestPoints;
    bestPoints=NULL;
    delete BestA;
    BestA=NULL;
    delete Best0;
    Best0=NULL;
    delete Best1;
    Best1=NULL;
    delete Best2;
    Best2=NULL;
}

// Derive the effective configuration from the requested one and the
// current table size, then (re)allocate every buffer that depends on it.
//
// Effective bounds: each minimum is raised to at least 2 and each maximum
// is capped at the number of examples X->m. A model degree is enabled
// (Go0/Go1/Go2 non-zero) only when its range is non-empty and, in
// per-degree mode, its combination count is positive.
//
// Returns TRUE when at least one model degree is usable and all buffers
// were allocated. On FALSE the buffers are released and Go0/Go1/Go2 are
// all zero, which makes eval() return immediately with its error value
// instead of running on buffers that belong to a previous configuration.
BOOL Lazy::checkAndAllocateMemory()
{
    int i,nz=X->n,mx=X->m,nz2;
    double *Zvec1,*Zvec2;

    idm0 = MMAX(idm0a,2);
//    idM0 = MMAX(idM0a,idm0);
    idM0 = MMIN(idM0a,mx);
  
    idm1 = MMAX(idm1a,2);
//    idM1 = MMAX(idM1a,idm1);
    idM1 = MMIN(idM1a,mx);
  
    idm2 = MMAX(idm2a,2);
//    idM2 = MMAX(idM2a,idm2);
    idM2 = MMIN(idM2a,mx);
    
    Go0  = ((idM0-idm0+1)>0)? 1 : 0;
    Go1  = ((idM1-idm1+1)>0)? 1 : 0;
    Go2  = ((idM2-idm2+1)>0)? 1 : 0;

    // number of nearest neighbours eval() has to collect
    noBestDistIndx = MMAX(MMAX(idM0,idM1),idM2);
    
    if (!acbA)
    {
        // per-degree combination: never more models than candidates
        cbA=0;
        cb0=MMIN(acb0,idM0-idm0+1);
        cb1=MMIN(acb1,idM1-idm1+1);
        cb2=MMIN(acb2,idM2-idm2+1);

        // A negative count can neither size a KeepBests nor enable a model;
        // treat it as "combine none".
        cb0=MMAX(cb0,0);
        cb1=MMAX(cb1,0);
        cb2=MMAX(cb2,0);

        cb0*=Go0; cb1*=Go1; cb2*=Go2;
        Go0*=cb0; Go1*=cb1; Go2*=cb2;
    }
    else
    {
        // shared combination: capped by the total number of candidates
        i = Go0*(idM0-idm0+1)+Go1*(idM1-idm1+1)+Go2*(idM2-idm2+1);
        cbA = MMIN(acbA, i);
        cb0 = 0;
        cb1 = 0;
        cb2 = 0;
    }

    // Buffers of the previous configuration are useless from here on, even
    // when the new configuration turns out to be unusable.
    freeStuff();

    if ((acbA && (cbA<=0)) || (Go0+Go1+Go2==0))
    {
        Go0=0; Go1=0; Go2=0;
        return FALSE;
    }

    Z1       = (double**)calloc(noBestDistIndx+1,sizeof(double*));
    Zvec1    = (double*)calloc((size_t)(noBestDistIndx+1)*nz,sizeof(double));
    W        = (double*)calloc(noBestDistIndx+1,sizeof(double));
    BestDist = (double*)calloc(noBestDistIndx+1,sizeof(double));

    // Hand the matrix storage to Z1 right away so freeStuff() releases it.
    if (Z1) Z1[0]=Zvec1;
    if (!Z1 || !Zvec1 || !W || !BestDist)
    {
        if (!Z1) free(Zvec1);
        freeStuff();
        Go0=0; Go1=0; Go2=0;
        return FALSE;
    }
    // every row of Z1 starts with the constant term 1
    for (i=0; i<=noBestDistIndx; i++,Zvec1+=nz) { Z1[i]=Zvec1; *Zvec1=1; }

    bestPoints= new KeepBests(noBestDistIndx,nz-1);

    // maxNZ ends up as the parameter count of the highest enabled degree
    if (Go0)
    {
        maxNZ=1;
        if (!cbA) Best0=new KeepBests(cb0+1);
    }

    if (Go1)
    {
        maxNZ=nz;
        if (!cbA) Best1=new KeepBests(cb1+1,nz);
    }

    if (Go2)
    {
        nz2=nz*(nz+1)/2;
        maxNZ=nz2;

        Z2    = (double**)calloc(idM2,sizeof(double*));
        Zvec2 = (double*)calloc((size_t)idM2*nz2,sizeof(double));
        if (Z2) Z2[0]=Zvec2;
        if (!Z2 || !Zvec2)
        {
            if (!Z2) free(Zvec2);
            freeStuff();
            Go0=0; Go1=0; Go2=0;
            return FALSE;
        }
        for (i=0; i<idM2; i++,Zvec2+=nz2) Z2[i] = Zvec2;

        if (!cbA) Best2=new KeepBests(cb2+1,nz2);
    }

    if (cbA) BestA=new KeepBests(cbA+1,maxNZ);

    return TRUE;
}

void Lazy::setLAMBDA(double _LAMBDA)
{
    LAMBDA=_LAMBDA;
};

void Lazy::unsetLAMBDA()
{
    LAMBDA=1e6;
}

// Install a callback that eval() invokes with the array of neighbour
// distances and their count; when it returns false, eval() gives up and
// reports its error value.
void Lazy::setDistanceCheck(BOOL (*_g)(double *,int))
{
    this->g = _g;
}

void Lazy::unsetDistanceCheck()
{
    g=NULL;
}

void Lazy::eval(double *Q,double *y_hat, double **t_hat2, double *err)
{
// mx: number of example
// nz: number of column
    KeepBests *bestPoints=this->bestPoints,
              *Best0=this->Best0,
              *Best1=this->Best1,
              *Best2=this->Best2,
              *BestA=this->BestA;
    double  *Wvec=this->Wvec,
           LAMBDA=this->LAMBDA,
               *C=this->C,
               *W=this->W,
        *BestDist=this->BestDist,
             **Z1=this->Z1;
    int      idm0=this->idm0, 
             idM0=this->idM0, 
             idm1=this->idm1, 
             idM1=this->idM1, 
             idm2=this->idm2, 
             idM2=this->idM2,
             val0=this->val0, 
             val1=this->val1, 
             val2=this->val2,
              cb0=this->cb0, 
              cb1=this->cb1, 
              cb2=this->cb2,
              cbA=this->cbA,
   noBestDistIndx=this->noBestDistIndx;
   
    int mx=X->m, nz=X->n, nx=nz-1, nz2=(nz*(nz+1))>>1;

    double dist,e, b, tmp, sse, eC, y, w, errorP, *t_hat=NULL, **XP=X->p;
    int vl, i, j, k, p, m;

    // the next 8 variables are re-initialised later if necessary
    double **Z2;
    double *Zvec2;
    double **v1;
    double *t1;
    double *a1;
    double **v2;
    double *t2;
    double *a2;

    if (Go0+Go1+Go2==0) { *err=1e6; return; };

    if (t_hat2) 
    {
        t_hat=*t_hat2=(double*)calloc(maxNZ,sizeof(double));
        memset(t_hat,0,maxNZ*sizeof(double));
        if (mx<nz) { *err=1e6; return; }
    }
    
    if (noBestDistIndx<1) { *err=1e6; return; }
    
// search the noBestDistIndx closest points to the query point and put them in Z and W
    bestPoints->reset();
    if (Wvec)
    {
        for (i=0; i<mx; i++)
        {
            dist=0.0;
            for (j=0; j<nx; j++)
            {
                C[j]=XP[j][i]-Q[j];
                #ifdef L1METRIC 
                    dist+=Wvec[j]*fabs(C[j]);
                #else
                    dist+=Wvec[j]*SQR(C[j]);
                #endif
            }
            bestPoints->add(dist,XP[nx][i], C);
        };
    } else 
    {
        for (i=0; i<mx; i++)
        {
            dist=0.0;
            for (j=0; j<nx; j++)
            {
                C[j]=XP[j][i]-Q[j];
                #ifdef L1METRIC 
                    dist+=fabs(C[j]);
                #else
                    dist+=SQR(C[j]);
                #endif
            }
            bestPoints->add(dist,XP[nx][i], C);
        }
    }

    for (p=0; p<noBestDistIndx; p++)
    {
        #ifdef L1METRIC
            BestDist[p]=bestPoints->getKey(p);
        #else
            BestDist[p]=sqrt(bestPoints->getKey(p));
        #endif
        W[p]=bestPoints->getValue(p);
        for(j=0;j<nx;j++) Z1[p][j+1]=bestPoints->getOptValue(p,j);
    }
        
    if ((!t_hat)&&(BestDist[0]==0))
    {
        *y_hat=W[0];
        *err=0;
        return;
    }

//        if ((BestDist[0]>0.1)||(BestDist[nx-1]>1)) { *err=1e6; return; };

    if ((g)&&(!(*g)(BestDist,noBestDistIndx))) { *err=1e6; return; };
        
    if (cbA) BestA->reset();

        if (Go0)
        {
            if (!cbA) Best0->reset();

            y = W[0];
            eC = 1;

            for (k=1; k<idM0; k++)
            {
                if (val0)
                {
                    y = (k * y + W[k])/(k+1);
                    e = 0;
//                    vl = (val0*(val0<k+1))? val0 : k+1;
                    vl = (val0<k+1)? val0 : k+1;
                    for (i=0;i<vl;i++) e+=pow(y-W[i],2);
                    eC = e * vl / (vl-1);
                }else
                {
                    eC = eC*(k+1)*pow(k-1,2)/pow(k,3) + pow(W[k]-y,2)/k;
                    y = (k * y + W[k])/(k+1);
                }

                if (k>=idm0-1)
                {
                    if (cbA)
                    {
                        if (t_hat) BestA->add(eC,y,y);
                        else BestA->add(eC,y);
                    }else Best0->add(eC,y);
                }
            }
        }


        if (Go1)
        {
            if (!cbA) Best1->reset();
            
            v1=this->v1;
            t1=this->t1;
            a1=this->a1;

            memset(v1[0],0,nz*nz*sizeof(double));
            for (j=0; j<nz; j++) v1[j][j] = LAMBDA;
            memset(t1,0,nz*sizeof(double));
        
            /*
            W(=W1) and Z1 are already initialised.
            */

            for (k=0; k<idM1; k++)
            {
                e = W[k];
                b = 1;
                for (i=0; i<nz; i++)
                {
                    tmp=0;
                    for(j=0; j<nz; j++) tmp += v1[j][i] * Z1[k][j];
                    a1[i] = tmp;
                    b += Z1[k][i] * tmp;
                    e -= Z1[k][i] * t1[i];
                }
                for (i=0; i<nz; i++)
                    for(j=0; j<nz; j++)
                        v1[j][i] -= a1[i] * a1[j] / b;
                for (i=0; i<nz; i++)
                {
                    tmp=0;
                    for(j=0; j<nz; j++)
                        tmp += v1[j][i] * Z1[k][j];
                    t1[i] += e * tmp;
                }

                if (k>=idm1-1)
                {
                    vl = (val1*(val1<k+1))? val1 : k+1;
                    sse=0;
                    for(m=0; m<vl; m++)
                    {
                        e = W[m];
                        b = 1;
                        for (i=0; i<nz; i++)
                        {
                            tmp=0;
                            for(j=0; j<nz; j++)
                                tmp += v1[j][i] * Z1[m][j];
                            b -= Z1[m][i] * tmp;
                            e -= Z1[m][i] * t1[i];
                        }
                        sse += pow(e/b,2);
                    }
                    eC = sse / (k+1);

                    if (cbA)
                    {
                        if (t_hat) BestA->add(eC,t1[0],t1,nz);
                        else BestA->add(eC,t1[0]);
                    } else
                    {
                        if (t_hat) Best1->add(eC,t1[0],t1);
                        else Best1->add(eC,t1[0]);
                    }
                }
            }
        }


        if (Go2)
        {
            
            if (!cbA) Best2->reset();

            v2=this->v2;
            t2=this->t2;
            a2=this->a2;

            memset(v2[0],0,nz2*nz2*sizeof(double));
            for (j=0; j<nz2; j++) v2[j][j] = LAMBDA;
            memset(t2,0,nz2*sizeof(double));

            Z2=this->Z2;
            Zvec2 = *Z2;
            for(i=0;i<idM2;i++)
            {
                *(Zvec2++) = 1.0;
                for(j=0;j<nx;j++) *(Zvec2++) = Z1[i][j+1];
                for(p=0;p<nx;p++)
                    for(m=p;m<nx;m++)
                        *(Zvec2++) = Z1[i][p+1]*Z1[i][m+1];
            };

            for (k=0; k<idM2; k++)
            {
                e = W[k];
                b = 1;
                for (i=0; i<nz2; i++)
                {
                    tmp=0;
                    for(j=0; j<nz2; j++) tmp += v2[j][i] * Z2[k][j];
                    a2[i] = tmp;
                    b += Z2[k][i] * tmp;
                    e -= Z2[k][i] * t2[i];
                }
                for (i=0; i<nz2; i++)
                    for(j=0; j<nz2; j++)
                        v2[j][i] -= a2[i] * a2[j] / b;
                for (i=0; i<nz2; i++)
                {
                    tmp=0;
                    for(j=0; j<nz2; j++)
                        tmp += v2[j][i] * Z2[k][j];
                    t2[i] += e * tmp;
                }

                if (k>=idm2-1)
                {
                    vl = (val2*(val2<k+1))? val2 : k+1;
                    sse=0;
                    for(m=0; m<vl; m++)
                    {
                        e = W[m];
                        b = 1;
                        for (i=0; i<nz2; i++)
                        {
                            tmp=0;
                            for(j=0; j<nz2; j++)
                                tmp += v2[j][i] * Z2[m][j];
                            b -= Z2[m][i] * tmp;
                            e -= Z2[m][i] * t2[i];
                        }
                        sse += pow(e/b,2);
                    }
                    eC = sse / (k+1);

                    if (cbA)
                    {
                        if (t_hat) BestA->add(eC,t2[0],t2,maxNZ);
                        else BestA->add(eC,t2[0]);
                    } else
                    {
                        if (t_hat) Best2->add(eC,t2[0],t2);
                        else Best2->add(eC,t2[0]);
                    }
                }
            }
        }

    y=0;
    w=0;
    errorP=0;
    if (cbA)
    {
        if (t_hat)
        {
            for ( j=0; j<maxNZ; j++ ) t_hat[j]=0;
            for(i=0;i<cbA;i++)
            {
                e= BestA->getKey(i);
                e = (e==0)? 1E-20 : e;
                for ( j=0; j<maxNZ; j++) t_hat[j]+=BestA->getOptValue(i,j)/e;
                y += BestA->getValue(i)/e;
                w += 1/e;
                errorP+=e;             
            }
        } else
        {        
            for(i=0;i<cbA;i++)
            {    
                e= BestA->getKey(i);
                e = (e==0)? 1E-20 : e;
                y += BestA->getValue(i)/e;
                w += 1/e;
                errorP+=e;                
            }
        }
        errorP/=cbA;
    }else
    {
        if (t_hat)
        {
            for ( j=0; j<maxNZ; j++ ) t_hat[j]=0;
            for(i=0;i<cb0;i++)
            {
                e= Best0->getKey(i);
                e = (e==0)? 1E-20 : e;
                t[0]+=Best0->getOptValue(i,0)/e;                    
                y += Best0->getValue(i)/e;
                w += 1/e;
                errorP+=e;                
            }

            for(i=0;i<cb1;i++)
            {
                e= Best1->getKey(i);
                e = (e==0)? 1E-20 : e;
                for ( j=0; j<nz; j++) t_hat[j]+=Best1->getOptValue(i,j)/e;
                y += Best1->getValue(i)/e;
                w += 1/e;
                errorP+=e;
            }

            for(i=0;i<cb2;i++)
            {
                e= Best2->getKey(i);
                e = (e==0)? 1E-20 : e;
                for ( j=0; j<nz2; j++) t_hat[j]+=Best2->getOptValue(i,j)/e;
                y += Best2->getValue(i)/e;
                w += 1/e;
                errorP+=e;
            }
        } else
        {        
            for(i=0;i<cb0;i++)
            {
                e= Best0->getKey(i);
                e = (e==0)? 1E-20 : e;
                y += Best0->getValue(i)/e;
                w += 1/e;
                errorP+=e;                
            }
            for(i=0;i<cb1;i++)
            {
                e= Best1->getKey(i);
                e = (e==0)? 1E-20 : e;
                y += Best1->getValue(i)/e;
                w += 1/e;
                errorP+=e;
            }
            for(i=0;i<cb2;i++)
            {
                e= Best2->getKey(i);
                e = (e==0)? 1E-20 : e;
                y += Best2->getValue(i)/e;
                w += 1/e;
                errorP+=e;
            }
        }
        errorP/=cb0+cb1+cb2;
    }
        
    if (t_hat)
        for (i=0;i<maxNZ;i++) t_hat[i]/=w;
    *y_hat=y/w;
    *err=errorP;
};
