/***************************************
 * copyright (c) Vanden Berghen Frank  *
 * V 1.2                               *
 * *************************************/

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "BAGFSC45.h"
#include "C45.h"
#include "textUtils.h"
#include "tools.h"

// When __NO_DATASET__ is defined, __EVAL_ONLY__ is defined as well; every
// section guarded by "#ifndef __EVAL_ONLY__" in this file (tree building,
// pruning, saving, the K/B searches) is then compiled out, leaving only
// loading and evaluation of saved models.
#ifdef __NO_DATASET__
#define __EVAL_ONLY__
#endif

#ifndef __EVAL_ONLY__

// Number of cross-validation folds used by FindOptimalK.
#define CrossValPart 10
// Number of candidate values ("steps") tried when searching for the optimal K:
// K = MaxAtt*(k+1)/K_resolution for k = 0..K_resolution-1.
#define K_resolution 10
// Number of times the evaluation is repeated at each step.
#define K_retry 1
// B_max: number of trees built before B is searched, the upper bound on B,
//        and the ensemble size used inside FindOptimalK.
// B_min: B_min-5 (at least 1) is the starting lower bound of the B search.
#define B_max 100
#define B_min 10

// NOTE(review): not referenced in this file; presumably tree-induction
// settings (minimum items per branch, pruning confidence factor) read by
// BuildTree/Prune — confirm.
ItemNo    MINOBJS   = 2;
double    CF = 0.25;

#endif

//=======================================================
//=														=
//=					C45 object							=
//=														=
//=======================================================

// Reads lines from 'stream' into 'line' (a 3000-char buffer) until one is not
// empty according to emptyline(), and returns a pointer into 'line' just past
// the leading indentation: spaces, tabs and the '|' markers written by save().
// At end of file 'line' is set to the empty string, so callers report a syntax
// error instead of parsing the previously read (stale) line a second time.
// Exits with code 254 on a read error.
char *C45::nextItem(char *line,FILE *stream)
{
	char *tline;
	for (;;)
	{
		if (fgets(line,3000,stream)==NULL)
		{
			// fgets leaves the buffer untouched when it reads nothing; without
			// this a truncated file after an "If" line recursed forever.
			line[0]='\0';
			break;
		};
		if (!emptyline(line)) break;
	};
	if (ferror(stream)) exit(254);
	tline=line;
	while ((*tline==' ')||(*tline=='|')||(*tline==9)) tline++;
	return tline;
};

// Recursively parses one node, and the subtree below it, from the text format
// written by save_work():
//   internal node: "If <feature> < <cut>", the lower subtree, "else",
//                  then the upper subtree;
//   leaf:          "class <label> <confidence>".
// Feature numbers are 1-based in the file and stored 0-based in Tested.
// Leaf labels are stored as read; MinClass is lowered to the smallest label
// seen so that load_work2() can make the labels relative afterwards.
// 'line' is the caller's 3000-char read buffer (see nextItem).
// Malformed input or allocation failure terminates the program.
void C45::load_work(FILE *stream,Node *pNode,int nClasses,char *line)
{
	pNode->ClassDist=NULL;
	char *tline=nextItem(line,stream);
	switch (*tline)
	{
	case 'i':
	case 'I':
		tline++;
		if ((*tline!='f')&&(*tline!='F'))
		{
			printf("invalid C45 (expected 'If').\n");
			exit(245);
		};
		tline++;
		pNode->Tested=(int)lire_double(&tline)-1;
		if (pNode->Tested<0)
		{
			printf("invalid C45 (bad feature number).\n");
			exit(244);
		};
		while (*tline==' ') tline++;
		if (*tline!='<')
		{
			printf("invalid C45 (expected '<').\n");
			exit(245);
		};
		tline++;
		pNode->NodeType=ThreshContin;
		pNode->Cut=lire_double(&tline);

		pNode->lower=(Node *)malloc(sizeof(Node));
		pNode->upper=NULL;
		if (pNode->lower==NULL)
		{
			fprintf(stderr,"C45: out of memory.\n");
			exit(210);
		};
		load_work(stream,pNode->lower,nClasses,line);

		tline=nextItem(line,stream);
		if (strncmp(tline,"else",4)!=0)
		{
			printf("invalid C45 (expected 'else').\n");
			exit(243);
		};

		pNode->upper=(Node *)malloc(sizeof(Node));
		if (pNode->upper==NULL)
		{
			fprintf(stderr,"C45: out of memory.\n");
			exit(210);
		};
		load_work(stream,pNode->upper,nClasses,line);
		break;
	case 'c':
	case 'C':
		tline++;
		if (strncmp(tline,"lass",4)!=0)
		{
			printf("invalid C45 (expected 'class').\n");
			exit(245);
		};
		tline+=4;
		pNode->lower=NULL;
		pNode->upper=NULL;	// leaves have no children; keep the node fully initialised
		pNode->NodeType=LeafType;
		pNode->Leaf=(ClassNo)lire_double(&tline);
		pNode->Confidence=(double)lire_double(&tline);
		MinClass=MIN(pNode->Leaf,MinClass);
		if (pNode->Leaf>=nClasses)
		{
			printf("invalid C45 (bad class number).\n");
			exit(242);
		};
		break;
	default:
		printf("invalid C45 (invalid syntax).\n");
		exit(243);
	};
};

// Loads one tree from 'stream' in the text format written by save().
// nClasses is one past the largest accepted leaf label (load_work rejects
// labels >= nClasses). After parsing, load_work2 makes the leaf labels
// relative to MinClass, the smallest label found in this tree; eval() adds
// MinClass back. MinClass starts at 255 (presumably the largest ClassNo — confirm).
C45::C45(FILE *stream, int nClasses)
{
	char buffer[3000];
	buffer[0]='\0';
	MinClass=255;
	head=(Node *)malloc(sizeof(Node));
	if (head==NULL)
	{
		fprintf(stderr,"C45: out of memory.\n");
		exit(210);
	};
	load_work(stream,head,nClasses,buffer);
	load_work2(head);
};

// Makes every leaf label in the subtree rooted at N relative to MinClass by
// subtracting it. Internal nodes are walked lower branch first, then upper.
void C45::load_work2(Node *N)
{
	if (N->NodeType==LeafType)
	{
		N->Leaf-=MinClass;
		return;
	}
	load_work2(N->lower);
	load_work2(N->upper);
}

// Frees the subtree rooted at N: both children (when N->lower is set), the
// node's ClassDist array if any, and finally N itself. Nodes are malloc'ed,
// so free() is used throughout.
void C45::delete_work(Node *N)
{
	Node *lo=N->lower;
	if (lo!=NULL)
	{
		Node *hi=N->upper;
		delete_work(lo);
		delete_work(hi);
	}
	if (N->ClassDist!=NULL) free(N->ClassDist);
	free(N);
}

// Classifies one sample by walking the tree from the root. At each internal
// node the sample goes to 'lower' when features[Tested] <= Cut, else to 'upper'.
// If 'confidence' is non-NULL it receives the reached leaf's Confidence.
// Returns the absolute class label (leaf label + MinClass).
ClassNo C45::eval(double *features, double *confidence)
{
	Node *node;
	for (node=head; node->NodeType!=LeafType; )
		node=(features[node->Tested]<=node->Cut) ? node->lower : node->upper;
	if (confidence!=NULL) *confidence=node->Confidence;
	return node->Leaf+MinClass;
}

#ifndef __EVAL_ONLY__

// Tree induction and pruning entry points used by the DataSet constructor of
// C45; they are not defined in the visible part of this file.
Node *BuildTree(double **_Item, int _MaxItem, int _MaxClass, 
				int _MaxAtt, Boolean *_FS, Boolean useWeight);
Boolean Prune(Node *T,double **_Item, int _nItem,int _MaxClass, int _MaxAtt, Boolean useWeight);

// Builds a tree from data set D using the attribute mask FS, then prunes it
// when 'prune' is set. 'useWeight' is forwarded to both BuildTree and Prune.
// MinClass is copied from D; eval() and save() add it to the leaf labels.
C45::C45(DataSet *D, Boolean *FS, Boolean prune, Boolean useWeight)
{
	MinClass=D->MinClass;
	head=BuildTree(D->Item,D->nItem,D->MaxClass,D->MaxAtt,FS,useWeight);
	if (!prune) return;
	// Prune's Boolean result is not used here.
	Prune(head,D->Item,D->nItem,D->MaxClass,D->MaxAtt,useWeight);
}

// Recursively writes the subtree rooted at N in the text format read back by
// the C45(FILE*, int) constructor:
//   leaf:          "<indent>class <label> <confidence>"
//   internal node: "<indent>If <feature> < <cut>", lower subtree,
//                  "<indent>else", upper subtree.
// 'buffer' holds the current indentation prefix of length l. Each level appends
// 4 characters ("|   " before the lower branch, "    " before the upper one),
// and the prefix is truncated back afterwards. 'ind' is the depth and is only
// passed on.
// The file says "<" although eval() sends features[Tested] <= Cut to the lower
// branch; the loader only checks for the '<' token.
// NOTE(review): save() passes a 3000-char buffer, so trees deeper than about
// 750 levels would overflow it.
void C45::save_work(FILE *F, Node *N, short ind,char *buffer, int l)
{
//    fprintf(F,"%sitems: %i\n", buffer,N->Items);
//    fprintf(F,"%serrors: %0.0f\n", buffer,N->Errors);

	if (N->NodeType==LeafType)
	{
		fprintf(F,"%sclass %u %f\n",buffer,N->Leaf+MinClass, N->Confidence);
	} else
	{
		// %.16e gives 17 significant digits, enough for a double to survive the
		// save/load round trip (assuming lire_double parses all digits). Plain %e
		// kept only 7, so a reloaded Cut could differ and send samples equal to
		// the original threshold down the other branch.
		fprintf(F,"%sIf %u < %.16e\n",buffer,N->Tested+1,N->Cut);
		strcpy(buffer+l,"|   ");
		save_work(F,N->lower,ind+1,buffer, l+4);
		buffer[l]='\0';
		fprintf(F,"%selse\n",buffer);
		strcpy(buffer+l,"    ");
		save_work(F,N->upper,ind+1,buffer, l+4);
		buffer[l]='\0';
	};
};

// Writes the whole tree to F through save_work, starting with an empty
// indentation prefix in a 3000-char scratch buffer.
void C45::save(FILE *F)
{
	char indent[3000]="";
	save_work(F,head,0,indent,0);
}

// Returns the root's error rate Errors/Items, or 1.0 when there is no tree or
// the root covers no items (avoids a division by zero).
// NOTE(review): Items/Errors are presumably filled in by BuildTree/Prune; the
// file loader (load_work) does not set them, so for trees loaded from a stream
// the result is not meaningful.
double C45::errorEstimate() 
{ 
    if (!head || head->Items<=0) return 1.0;
    return head->Errors/head->Items; 
}

// Defined elsewhere; called by ~C45() in builds without __EVAL_ONLY__.
// NOTE(review): presumably releases working storage used by BuildTree/Prune — confirm.
void freeLocalVarForBuild();

#endif

// Frees the whole tree (if any). In builds without __EVAL_ONLY__ it also calls
// freeLocalVarForBuild() every time a C45 is destroyed.
C45::~C45()
{
	Node *root=head;
	if (root!=NULL) delete_work(root);
#ifndef __EVAL_ONLY__
	freeLocalVarForBuild();
#endif
}



//=======================================================
//=														=
//=					BAGFSC45 object						=
//=														=
//=======================================================


// Classifies one sample by majority vote of the B trees; on a tie the
// highest-indexed class wins. If 'confidence' is non-NULL it receives the
// fraction of trees that voted for the returned class.
// Returns the absolute class label of the sample.
ClassNo BAGFSC45::eval(double *features, double *confidence)
{
	int t, c;
	ClassNo winner=0;
	ItemCount top;

	// one counter per class label in [MinClass, nClasses)
	memset(vote,0,(nClasses-MinClass)*sizeof(double));
	for (t=0; t<B; t++)
	{
		ClassNo label=TreeTable[t]->eval(features, NULL);
		vote[label-MinClass]++;
	}

	top=vote[0];
	for (c=1; c<(nClasses-MinClass); c++)
	{
		if (vote[c]>=top)
		{
			winner=c;
			top=vote[c];
		}
	}

	if (confidence) *confidence=((double)top)/B;
	return winner+MinClass;
}

// Loads a bagged ensemble from 'stream'. Expected layout: a line holding
// nClasses, a line holding B, then B trees in the C45 text format.
// NOTE(review): save() writes a "BAGFS Trees v1.00 ..." title line before these
// two numbers; this constructor does not read it, so presumably the caller
// consumes it first — confirm.
// NOTE(review): 'name' is not set here although ~BAGFSC45 frees it when
// non-NULL; presumably a base class initialises it — confirm.
// Exits on a truncated/invalid header (243) or on allocation failure (210).
BAGFSC45::BAGFSC45(FILE *stream): vote(NULL)
{
    int j;
    char buffer[300];

    // fgets leaves 'buffer' untouched when it reads nothing, so a missing line
    // must be detected rather than parsing indeterminate bytes with atol.
    if (fgets(buffer,300,stream)==NULL)
    {
        fprintf(stderr,"BAGFSC45: truncated model header.\n");
        exit(243);
    };
    nClasses=atol(buffer);
    if (fgets(buffer,300,stream)==NULL)
    {
        fprintf(stderr,"BAGFSC45: truncated model header.\n");
        exit(243);
    };
    B=atol(buffer);
    // Non-positive values cannot describe a usable ensemble (eval divides by B).
    if ((nClasses<=0)||(B<=0))
    {
        fprintf(stderr,"BAGFSC45: invalid model header.\n");
        exit(243);
    };

    TreeTable=(C45**)malloc(B*sizeof(C45*));
    if (TreeTable==NULL)
    {
        fprintf(stderr,"BAGFSC45: out of memory.\n");
        exit(210);
    };

    // MinClass becomes the smallest label over all trees (each tree's own MinClass).
    MinClass=255;
    for (j=0; j<B; j++)
    {
        TreeTable[j]=new C45(stream,nClasses);
        MinClass=MIN(TreeTable[j]->MinClass,MinClass);
    }
    // one vote counter per class label in [MinClass, nClasses)
    vote=(double*)malloc((nClasses-MinClass)*sizeof(double));
    if (vote==NULL)
    {
        fprintf(stderr,"BAGFSC45: out of memory.\n");
        exit(210);
    };
};

// Releases the name, the vote buffer, the trees TreeTable[0..B-1] and then
// the table itself. Only the first B slots are deleted; BuildAllTrees has
// already deleted any trees beyond B.
BAGFSC45::~BAGFSC45()
{
    if (name!=NULL) free(name);
    if (vote!=NULL) free(vote);
    for (int t=0; t<B; ++t) delete TreeTable[t];
    if (TreeTable!=NULL) free(TreeTable);
}

#ifndef __EVAL_ONLY__

// Writes the ensemble to f: a title line, nClasses, B, then every tree via
// C45::save. Progress is printed on stdout, one dot per tree.
void BAGFSC45::save(FILE *f)
{
	printf("saving BAGFS trees\n");
	fprintf(f,"BAGFS Trees v1.00 (c) Vanden Berghen Frank\n%i\n%i\n",
	        nClasses, B);
	for (int t=0; t<B; ++t)
	{
		printf(".");
		TreeTable[t]->save(f);
	}
	printf("\n");
}

// Fills the Boolean vector FS (MaxAtt entries) with a random attribute subset:
// exactly min(NumberOfAtt, MaxAtt) entries are set to 1, all others to 0.
// Each pick is the m-th not-yet-selected attribute, with m drawn from rand1()
// (uniform among the remaining ones, assuming rand1() is uniform on [0,1)).
void BAGFSC45::generateFS(Boolean *FS, int MaxAtt, int NumberOfAtt)
{
	int remaining=MaxAtt;	// attributes not selected yet
	int i,m;
	memset(FS,0,MaxAtt*sizeof(Boolean));
	// Asking for more attributes than exist used to scan past the end of FS.
	if (NumberOfAtt>MaxAtt) NumberOfAtt=MaxAtt;
	while (NumberOfAtt>0)
	{
		m=(int)(rand1()*remaining);
		// rand1() is defined elsewhere; if it can return 1.0, m would equal
		// 'remaining' and the scan below would run off the end of FS.
		if (m>=remaining) m=remaining-1;
		if (m<0) m=0;
		// locate the m-th (0-based) unselected attribute
		for (i=0; ; i++)
		{
			if (FS[i]) continue;
			if (m==0) break;
			m--;
		}
		FS[i]=1;
		NumberOfAtt--; remaining--;
	}
}

// Class label of item 'Case' of the data set 'ItemTest': read as a ClassNo
// stored right after the MaxAtt attribute values of the item. The expansion
// needs 'ItemTest' and 'MaxAtt' in scope at the point of use.
// NOTE(review): not used in the visible part of this file.
#define ClassTest(Case)	 (*((ClassNo*)(ItemTest->Item[Case]+MaxAtt)))

// Chooses K, the number of attributes each tree sees, by CrossValPart-fold
// cross-validation. The candidates are K = MaxAtt*(k+1)/K_resolution. For each
// fold and each distinct candidate, a B_max-tree ensemble is built on the
// fold's create set, and UpdateConfusionMatrix accumulates its results on the
// fold's test set into that candidate's confusion matrix (K_retry times).
// Returns the candidate whose matrix has the highest kappa. When no candidate
// scores above -INF (e.g. MaxAtt == 0), it returns MaxAtt.
int BAGFSC45::FindOptimalK(DataSet *AllDatas)
{
	int MaxAtt=AllDatas->MaxAtt,SQRMaxClass=SQR(AllDatas->MaxClass);
	DataSet *ItemCreate,*ItemTest;
	int i,j,k,NumberOfAtt,NumberOfAtt_old;
	int NumberOfAttB=MaxAtt;	// fallback; was previously returned uninitialised
	BAGFSC45 *CurrentC;
	double t,kappaB;

	// one SQRMaxClass-entry confusion matrix per candidate K, zero-initialised
	int *conf=(int*)calloc((size_t)SQRMaxClass*K_resolution,sizeof(int));
	if (conf==NULL)
	{
		fprintf(stderr,"BAGFSC45: out of memory.\n");
		exit(210);
	};

	for (j=0; j<CrossValPart; j++)
	{

		printf("determining optimal K: crossval part %i/%i\n",j+1,CrossValPart);

		ItemCreate=AllDatas->generate_Create_Set(j, CrossValPart);
		ItemTest=AllDatas->generate_Test_Set(j, CrossValPart);

		NumberOfAtt_old=0;
		for (k=0; k<K_resolution; k++)
		{
			NumberOfAtt=MaxAtt*(k+1)/K_resolution;
			// Several k map to the same K when MaxAtt < K_resolution; only the
			// first of them is evaluated, the others keep an all-zero matrix.
			if (NumberOfAtt==NumberOfAtt_old) continue;

			printf(" testing K=%i attributes\n",NumberOfAtt);
			for (i=0; i<K_retry; i++)
			{
				printf("  pass %i / %i\n",i+1,K_retry);

				// generate corresponding trees and score them on the test set
				CurrentC = new BAGFSC45(ItemCreate,NumberOfAtt,B_max);
				CurrentC->UpdateConfusionMatrix(conf+k*SQRMaxClass, ItemTest);
				delete (CurrentC);
			};
			NumberOfAtt_old=NumberOfAtt;
		};
		delete (ItemCreate); delete (ItemTest);
	};

	// Select the candidate with the highest kappa. NumberOfAtt_old must advance
	// here exactly as above so that the never-filled duplicate matrices are
	// skipped instead of being scored.
	kappaB=-INF; NumberOfAtt_old=0;
	for (k=0; k<K_resolution; k++)
	{
		NumberOfAtt=MaxAtt*(k+1)/K_resolution;
		if (NumberOfAtt==NumberOfAtt_old) continue;
		NumberOfAtt_old=NumberOfAtt;
		if ((t=kappa(conf+k*SQRMaxClass,AllDatas))>kappaB)
		{
			kappaB=t; NumberOfAttB=NumberOfAtt;
		};
	};
	free(conf);
	return NumberOfAttB;
};

// Builds n pruned trees, each restricted to K random attributes (generateFS).
// With n == 1 the single tree is trained on D itself. Otherwise each tree is
// trained on a fresh sample from D->generate_Bootstrap().
// Also sets B, MinClass and nClasses, and (re)allocates the vote buffer with
// D->MaxClass entries (== nClasses-MinClass, as eval() expects).
// Exits with code 210 on allocation failure.
void BAGFSC45::fillTreeTable(int n, int K, DataSet *D)
{
	Boolean *FS=(Boolean*)malloc(D->MaxAtt*sizeof(Boolean));
	int i;
	DataSet *bs;

	B=n;
	MinClass=D->MinClass;
	nClasses=D->MaxClass+D->MinClass;
	if (vote) free(vote);
	vote=(double*)malloc(D->MaxClass*sizeof(double));
	TreeTable=(C45**)malloc(n*sizeof(C45*));
	// malloc(0) may legitimately return NULL, so only non-empty requests are checked
	if (((FS==NULL)&&(D->MaxAtt>0))||((vote==NULL)&&(D->MaxClass>0))||((TreeTable==NULL)&&(n>0)))
	{
		fprintf(stderr,"BAGFSC45: out of memory.\n");
		exit(210);
	};
	printf("   Building all %i trees.\n",B);
	for (i=0; i<B; i++)
	{
		printf("."); fflush(stdout);
		generateFS(FS,D->MaxAtt,K);
		if (B==1) TreeTable[i]= new C45(D,FS,1);
		else
		{
			bs=D->generate_Bootstrap();
			TreeTable[i]= new C45(bs,FS,1);
			delete(bs);
		}
	};
	printf("\n");
	free(FS);
};

// Builds B_max trees (each on K random attributes) from a create set drawn
// from D. It then chooses the ensemble size B with two binary searches on a
// separate test set, and finally deletes the trees beyond the chosen B.
void BAGFSC45::BuildAllTrees(int K, DataSet *D)
// determines B and fills the table of C45 trees
{
	int i=0,Bh=B_max,Bl1=MAX(1,B_min-5),Bl2=MAX(1,B_min-5),Bm;
	DataSet *TestSet,*CreateSet;
	BAGFSC45 *other;


	// NOTE(review): 1/0.4 == 2.5. If generate_Create_Set/generate_Test_Set take
	// an int part count, this truncates to 2 (a 50/50 split rather than the 40%
	// test set mentioned below) — confirm the parameter type.
	CreateSet=D->generate_Create_Set(0, 1/0.4);

		/*
// Obsolete incremental method for finding B, kept for reference only
// (it would not compile as written).
// init until B=B_Min
		for (i=0; i<B_min; i++)
		{
			generate_bootstrap(&CreateSet,&bs);
			generateFS(FS,K);
			TreeTable[i]= BuildTree(&bs,FS);
			Prune(TreeTable[i],&bs);
		};

	    Bl=B-10;
		while (B<B_max)
		{
			generate_bootstrap(&CreateSet,&bs);
			generateFS(FS,K);
			TreeTable[B]= new C45(bs,FS,1);
			B++;
			if (!equivalent(TreeTable,B,Bl,TestSet) Bl=B-10;
			if (Bl<B-15) break;
		};
		return Bh;
*/

// Non-incremental method for finding B: build all B_max trees up front.
	fillTreeTable(B_max, K, CreateSet);
	delete CreateSet; 

	// 'other' shares this ensemble's TreeTable and vote buffer; only its B is
	// changed, so it stands for the first Bm trees.
	other=new BAGFSC45(this); // shallow copy
	TestSet=D->generate_Test_Set(0, 1/0.4); // testset is 40% of globalset

	// Pass 1: this->B is still B_max (set by fillTreeTable). Binary search on
	// [Bl1, Bh]: when Different() reports a difference for the first Bm trees,
	// Bm becomes the new lower bound, otherwise the new upper bound.
	// NOTE(review): Different() is defined elsewhere; presumably it compares the
	// results of this ensemble and 'other' on TestSet — confirm.
	printf("\nDetermining optimal B\n"
	       "  pass 1\n");
	while (Bh-Bl1>2)
	{
		Bm=(Bh+Bl1)/2;
		other->B=Bm;
		if (Different(other,TestSet)) Bl1=Bm; else Bh=Bm; 
	};
	printf("  pass 2\n"); 

	// Pass 2: the same search, with this ensemble reduced to 90% of B_max trees.
	Bh=(int)(B_max*0.9); B=Bh;
	while (Bh-Bl2>2)
	{
		Bm=(Bh+Bl2)/2;
		other->B=Bm;
		if (Different(other,TestSet)) Bl2=Bm; else Bh=Bm; 
	};

    // because of shallow copy: detach the borrowed pointers so that deleting
    // 'other' frees neither the shared trees nor the shared vote buffer.
	other->vote=NULL; other->B=0; other->TreeTable=NULL; delete other; 

	delete TestSet;
    
	// Final B: mean of the two lower bounds plus a margin of 5, capped at B_max.
	// Trees past B are deleted. Their TreeTable slots are left dangling but are
	// not read again: the loops in this file stop at B.
	B=MIN((Bl1+Bl2)/2+5,B_max);
	printf("Optimal B= %i\n",B);
	for (i=B; i<B_max; i++) delete(TreeTable[i]);
};

// Builds a bagged ensemble from D.
//   k == -1 : each tree may use all D->MaxAtt attributes;
//   k ==  0 : K is chosen by cross-validation (FindOptimalK);
//   other k : each tree sees k randomly chosen attributes.
//   b != 0  : exactly b trees are built (fillTreeTable);
//   b == 0  : the number of trees is searched (BuildAllTrees).
// The K finally used is stored in the member K.
BAGFSC45::BAGFSC45(DataSet *D, int k, int b): vote(NULL)
{
	int nAtt=(k==-1) ? D->MaxAtt : k;
	if (nAtt==0)
	{
		nAtt=FindOptimalK(D);
		printf("optimal K=%i / %i\n",nAtt,D->MaxAtt);
	}
	if (b==0) BuildAllTrees(nAtt,D);
	else fillTreeTable(b,nAtt,D);
	K=nAtt;
}

// Shallow copy of *t made by assignment: the copy shares t's TreeTable and vote
// pointers (BuildAllTrees relies on this). 'name' is reset to NULL so that
// destroying the copy does not free t's name.
BAGFSC45::BAGFSC45(BAGFSC45 *t)
{
	BAGFSC45 &source=*t;
	*this=source;
	name=NULL;
}

// Empty ensemble: no trees, no classes, no table or vote buffer allocated.
BAGFSC45::BAGFSC45()
	: B(0),
	  nClasses(0),
	  TreeTable(NULL),
	  vote(NULL)
{
}

// Mean of the per-tree error estimates; 1.0 for an ensemble with no trees.
double BAGFSC45::errorEstimate()
{
    if (B==0) return 1.0;

    double total=0;
    for (int t=0; t<B; ++t) total+=TreeTable[t]->errorEstimate();
    return total/B;
}

#endif
