#ifndef ClassifierTest_cpp//
#define ClassifierTest_cpp//






//using namespace UTILS;

namespace BIOS {


/*___________________________________________________________ */

// Initialises the test harness for evaluation scheme testMod and validates
// the sample arguments. Exactly one combination is accepted per mode;
// anything else throws BadFormat:
//  - tLeaveOneOut / tCrossValidation: only wholeSample, which is deep-copied.
//  - tTraining: only trainingSample, copied twice (training and test sample).
//  - tHoldout: trainingSample and testSample, each deep-copied.
// The samples passed in are copied, never stored; verbosity and lossFunction
// are stored as given.
// NOTE(review): verbosity is dereferenced unconditionally and must not be NULL.
void ClassifierTest::set(TestMode testMod, int numberOfFolds, VerbosityClass *verbosity, int classPosition, floatMLSample* wholeSample, floatMLSample* trainingSample, floatMLSample* testSample, LossFunction* lossFunction)
{
  if (verbosity->verbosityR.progress)
    cout << "\nLearning phase...";

  // Every owned pointer starts as NULL. The destructor zaps all of them and
  // selectAttributes() tests wholeSample against NULL. wholeSample is only
  // assigned in the cross-validation modes, so it must never be left
  // uninitialised in the training/holdout modes.
  completedClasses=NULL;
  selectedAttributes=NULL;
  this->wholeSample=NULL;
  this->trainingSample=NULL;
  this->testSample=NULL;
  this->algorithmParameters=NULL;
  this->originalAlgorithmParameters=NULL;

  this->lossFunction=lossFunction;
  this->verbosity=verbosity;
  testMode=testMod;
  this->numberOfFolds=numberOfFolds;
  this->classPosition=classPosition;
  this->originalClassPosition=classPosition;

  switch (testMode)
  {
    case tLeaveOneOut:
    case tCrossValidation:
      if (trainingSample!=NULL || testSample!=NULL || wholeSample==NULL)
        throw BadFormat("Error 1 in ClassifierTest::ClassifierTest");
      this->wholeSample=new floatMLSample(*wholeSample);
      break;
    case tTraining:
      if (wholeSample!=NULL || trainingSample==NULL || testSample!=NULL)
        throw BadFormat("Error 2 in ClassifierTest::ClassifierTest");
      // Training-set evaluation: the test sample is a copy of the training sample.
      this->trainingSample=new floatMLSample(*trainingSample);
      this->testSample=new floatMLSample(*trainingSample);
      break;
    case tHoldout:
      if (testSample==NULL || wholeSample!=NULL || trainingSample==NULL)
        throw BadFormat("Error 3 in ClassifierTest::ClassifierTest");
      this->trainingSample=new floatMLSample(*trainingSample);
      this->testSample=new floatMLSample(*testSample);
      break;
  }
}
/*___________________________________________________________ */

// Builds a test harness; all member initialisation and the validation of
// the sample combination for testMod are delegated to set().
ClassifierTest::ClassifierTest(floatMLSample* wholeSample, floatMLSample* trainingSample, floatMLSample* testSample, TestMode testMod, int numberOfFolds, VerbosityClass* verbosity, int classPosition, LossFunction* lossFunction)
{
  set(testMod,
      numberOfFolds,
      verbosity,
      classPosition,
      wholeSample,
      trainingSample,
      testSample,
      lossFunction);
}
/*___________________________________________________________ */

// Convenience overload: evaluates using the options stored in inputTUI.
// No selection file is forwarded (NULL is passed for it).
ClassificationResults* ClassifierTest::getAveragedAccuracy(InputTUI * inputTUI)
{
  char* const noSelectionFile=NULL;
  return getAveragedAccuracy(inputTUI->algType,
                             inputTUI->discMode,
                             inputTUI->algorithmParameters,
                             inputTUI->discretization,
                             inputTUI->selMode,
                             inputTUI->selection,
                             noSelectionFile,
                             inputTUI->stochastic);
}
/*___________________________________________________________ */

// Runs the evaluation scheme chosen at construction (testMode) and returns
// its results. algorithmParameters is stored in the member of the same name
// before dispatching. An unknown testMode throws BadFormat.
ClassificationResults* ClassifierTest::getAveragedAccuracy(AlgType algType, DiscMode discMod, floatList* algorithmParameters, floatList* discretization, SelMode selMode, floatList* selection, char* selectionFile, int stochastic)//, floatList* positions)
{
  if (verbosity->verbosityR.progress)
    cout << "\nbuilding the model ...";

  this->algorithmParameters=algorithmParameters;

  if (testMode==tLeaveOneOut)
    return LeaveOneOut(algType, discMod, discretization, selMode, stochastic, selection, selectionFile);
  if (testMode==tCrossValidation)
    return CrossValidation(algType, discMod, discretization, selMode, stochastic, selection, selectionFile);
  if (testMode==tHoldout)
    return Holdout(algType, discMod, discretization, selMode, stochastic, selection, selectionFile);
  if (testMode==tTraining)
    return Training(algType, discMod, discretization, selMode, stochastic, selection, selectionFile);

  throw BadFormat("\nClassifierTest::getAveragedAccuracy. testMode");
}
/*___________________________________________________________________*/

// Releases the samples and lists owned by this object via zap().
// algorithmParameters is not released here.
ClassifierTest::~ClassifierTest()
{
  zap(this->trainingSample);
  zap(this->testSample);
  zap(this->selectedAttributes);
  zap(this->wholeSample);
  zap(this->completedClasses);
}
/*___________________________________________________________ */

// Runs attribute selection on the training sample and projects the training,
// test and (when present) whole samples onto the selected attributes.
// Nothing happens when selMode is todos. Afterwards classPosition indexes the
// class attribute within the reduced samples.
// ZeroValue, NullValue, NonProb and BadFormat raised on the way are reported
// here via PrintMessage and not propagated.
void ClassifierTest::selectAttributes(SelMode selMode, AlgType algType, floatList*selection, char* selectionFile)
{
  try{
    if (selMode!=todos)
    {
      classPosition=originalClassPosition;
      floatMLSample* temp=NULL;
      if (trainingSample!=NULL)
      {
        zap (selectedAttributes);
        // Verbosity used only by the selection search. It is allocated just
        // where it is needed and released on every path, including when
        // select() throws.
        VerbosityClass* verbosity2=new VerbosityClass();
        verbosity2->verbosityR.selectionScores=verbosity->verbosityR.selectionScores;
        verbosity2->verbosityR.accuracy=true;
        try
        {
          selectedAttributes=trainingSample->select(classPosition, algType, selMode, verbosity2, lossFunction, selection, selectionFile, algorithmParameters);//, positions);
        }
        catch (...)
        {
          zap(verbosity2);
          throw;
        }
        zap(verbosity2);
        if (selectedAttributes==NULL)
        {
          cout << "Error in ClassifierTest::selectAttributes";
          throw NullValue();
        }
        trainingSample->listOfAttributes->select(selectedAttributes);
        temp=trainingSample->copySelection();
        zap(trainingSample);
        trainingSample=temp;
      }
      else
      {
        cout <<"Error 2 in ClassifierTest::selectAttributes";
        end();
      }
      testSample->listOfAttributes->select(selectedAttributes);
      temp=testSample->copySelection();
      zap(testSample);
      testSample=temp;
      // NOTE(review): in cross-validation, wholeSample is the source every
      // fold is copied from. Reducing it here means later folds start from
      // attributes chosen in an earlier fold, while Holdout() resets
      // classPosition to originalClassPosition. Confirm this is intended.
      if (wholeSample!=NULL)
      {
        wholeSample->listOfAttributes->select(selectedAttributes);
        temp=wholeSample->copySelection();
        zap(wholeSample);
        wholeSample=temp;
      }
      // The class index moves left by one for every attribute before it
      // that was not selected.
      int max=classPosition;
      for (int i=0; i<max;i++)
        if (selectedAttributes->findElement(i)==selectedAttributes->end())
          classPosition--;
    }
  }
  catch (ZeroValue& zv){zv.PrintMessage("ClassifierTest::selectAttributes");}
  catch (NullValue& nv){nv.PrintMessage("ClassifierTest::selectAttributes");}
  catch (NonProb& np){np.PrintMessage("ClassifierTest::selectAttributes");}
  catch (BadFormat& bf){bf.PrintMessage("ClassifierTest::selectAttributes");}
}
/*___________________________________________________________ */

// Returns the training sample owned by this object (not a copy).
floatMLSample* ClassifierTest::getTrainingSample()
{
  floatMLSample* const current=this->trainingSample;
  return current;
}
/*___________________________________________________________ */

// Returns the test sample owned by this object (not a copy).
floatMLSample* ClassifierTest::getTestSample()
{
  floatMLSample* const current=this->testSample;
  return current;
}
/*___________________________________________________________ */

// Returns the whole sample owned by this object (not a copy); may be NULL
// outside the cross-validation modes.
floatMLSample* ClassifierTest::getSample()
{
  floatMLSample* const current=this->wholeSample;
  return current;
}
/*___________________________________________________________ */

// Prepares `sample` for learning.
// With a discretization mode, any previous intervals are removed and new ones
// are computed with setIntervals (supervised, relative to classPosition). The
// attribute list is printed when attribute descriptions are requested and the
// list reports continuous attributes.
// Without discretization, the program terminates with a failure status if
// algType requires discrete data and the sample is not discretized.
void ClassifierTest::discretizeAttributes(DiscMode discMod, AlgType algType, floatMLSample* sample)
{
  if (discMod!=NoDiscretization)
  {
    sample->listOfAttributes->removeIntervals();
    sample->setIntervals(discMod, true, verbosity, classPosition);// true is supervised
    if (sample->listOfAttributes->isContinuous() && verbosity->verbosityR.attributeDescription)
      cout << "\n" << *sample->listOfAttributes;
  }
  else if (!sample->listOfAttributes->isDiscretized() && AlgTypeClass::requiresDiscretization(algType))
  {
    cout <<"Error in ClassifierTest::discretizeAttributes, sample " << sample->filename << " has continuous variables and you have chosen not to discretize.\n";
    // Terminate with a non-zero status so callers and scripts can detect
    // the failure; exit(0) would report success.
    exit(EXIT_FAILURE);
  }
}

/*____________________________________________________________________________________ */

// Creates (with new) the classifier selected by algType from `sample`. It
// uses the current classPosition, algorithmParameters, verbosity and
// lossFunction; aContinuousC45 additionally receives discMod. `discretization`
// is not used here. The caller must delete the result (Holdout() zaps it).
// For an unknown algType an error is printed and end() is called. If end()
// returns, NULL is returned, and likewise after a ZeroValue is caught.
Classifier* ClassifierTest::getClassifier(AlgType algType, DiscMode discMod, floatList* discretization, floatMLSample* sample)
{
  try{
    Classifier* classifier=NULL;

    switch (algType)
    {
      case aNB: classifier=new NB(sample, classPosition, algorithmParameters, verbosity, lossFunction); break;
      case aTAN: classifier=new TAN(sample, classPosition, algorithmParameters, verbosity, lossFunction); break;
      case aGTAN: classifier=new AN(sample, classPosition, algorithmParameters, verbosity, lossFunction); break;
      case aUAN: classifier=new UAN(sample, classPosition, algorithmParameters, verbosity, lossFunction); break;
      case aC45: classifier=new C45(sample, classPosition, algorithmParameters, verbosity, lossFunction); break;
      case aNN: classifier=new KNN(sample, classPosition, algorithmParameters, NULL, verbosity, lossFunction); break;
      case aContinuousC45: classifier=new C45(sample, classPosition, algorithmParameters, discMod, verbosity, lossFunction); break;
      case aEM: classifier=new EMClassifier<int>(sample, classPosition, algorithmParameters, verbosity, lossFunction); break;
      default: {cout <<"\nError in ClassifierTest::getClassifier: This learning algorithm is not implemented yet."; end();}
    }

    return classifier;
  }
  catch (ZeroValue& zv){zv.PrintMessage("ClassifierTest::getClassifier");end();}
  // Reached only if end() returns after a ZeroValue. Return an explicit NULL
  // instead of flowing off the end of a non-void function, which is
  // undefined behaviour.
  return NULL;
}
/*____________________________________________________________________________________ */

// k-fold cross-validation over wholeSample with k = numberOfFolds.
// The positions in sampling->Pos are cut into k consecutive, non-overlapping
// folds; the first (size % k) folds receive one extra pattern, so every
// pattern is tested exactly once. For each fold, the fold's rows become the
// test sample and the rest of a copy of wholeSample the training sample;
// Holdout() evaluates it. The per-fold results are averaged with GetMean into
// a newly allocated ClassificationResults, which the caller owns.
// Throws BadFormat if numberOfFolds is not positive (this includes
// LeaveOneOut on an empty sample).
ClassificationResults* ClassifierTest::CrossValidation(AlgType algType, DiscMode discMod, floatList* discretization, SelMode selMode, int stochastic, floatList* selectionList, char* selectionFile)
{
  if (numberOfFolds<=0)
    throw BadFormat("\nClassifierTest::CrossValidation. The number of folds must be positive");

  const int sampleSize=wholeSample->getSize();
  const int baseFoldSize=sampleSize/numberOfFolds;
  const int remainder=sampleSize%numberOfFolds;

  testSample=NULL;
  trainingSample=NULL;
  ClassificationResults** partialResult=new ClassificationResults*[numberOfFolds];
  // Presumably a permutation of 0..sampleSize-1 deciding fold membership;
  // TODO confirm the meaning of Sampling(size, false).
  Sampling* sampling=new Sampling(sampleSize, false);

  // Start of the current fold in sampling->Pos. This must be a running sum
  // of fold sizes; i*foldSize is wrong when fold sizes differ.
  int firstIndex=0;
  for (int i=0;i<numberOfFolds;i++)
  {
    const int foldSize=baseFoldSize+((i<remainder) ? 1 : 0);
    intList* indexVector=new intList();
    for (int j=0;j<foldSize;j++)
      indexVector->insertElement(sampling->Pos[firstIndex+j]);
    firstIndex+=foldSize;

    // The test fold's rows are taken from a full copy of the data;
    // presumably extractRowsWithPositionsIn also removes them from
    // trainingSample. TODO confirm.
    trainingSample=new floatMLSample(*wholeSample);
    testSample=trainingSample->extractRowsWithPositionsIn(indexVector);
    zap(indexVector);

    partialResult[i]=Holdout(algType, discMod, discretization, selMode, stochastic, selectionList, selectionFile);

    zap(trainingSample);
    zap(testSample);
  }

  zap(sampling);

  ClassificationResults* result=new ClassificationResults();
  result->GetMean(partialResult, numberOfFolds);
  // NOTE(review): only the pointer array is released here. The per-fold
  // results returned by Holdout() look like they leak unless GetMean takes
  // ownership of them. Confirm before zapping them individually.
  zaparr(partialResult);
  return result;
}
/*____________________________________________________________________________________ */

// Trains a classifier on trainingSample and evaluates it on testSample.
// Steps: validate the class attribute; drop training patterns whose class is
// missing; optionally discretize; run attribute selection; build the
// classifier; call getAccuracy on testSample. The returned results come from
// Classifier::getAccuracy.
// The classifier works on a private copy of the caller's algorithmParameters,
// and the member is restored to the caller's list afterwards, so repeated
// calls (one per cross-validation fold) always start from the same parameters.
// Throws BadFormat for an unsuitable class attribute or algorithm, or if no
// classifier could be built.
ClassificationResults* ClassifierTest::Holdout(AlgType algType, DiscMode discMod, floatList* discretization, SelMode selMode, int stochastic, floatList* selectionList, char* selectionFile)
{
  // Attribute selection in an earlier call (e.g. a previous fold) may have
  // shifted classPosition. Restore it before it is used below.
  classPosition=originalClassPosition;

  if (algType==aContinuousC45 && !this->trainingSample->listOfAttributes->isContinuous())
    throw BadFormat("\nClassifierTest::getAveragedAccuracy. We recommend to use non continuous C45 (-a=4), as there is no continuous atts in this data set\n");

  if (testSample->listOfAttributes->getElement(classPosition)->isContinuous())
    throw BadFormat(string("\nClassifierTest::getAveragedAccuracy. Attribute at position ")+tos(classPosition)+string(" is continuous so cannot be used as a class"));
  this->trainingSample->removePatternsWithMissingAttribute(classPosition);

  // Keep the caller's list and work on a copy (NULL stays NULL).
  this->originalAlgorithmParameters=algorithmParameters;
  if (originalAlgorithmParameters!=NULL)
    algorithmParameters=new floatList(*originalAlgorithmParameters);
  else
    algorithmParameters=NULL;

  if (discMod!=NoDiscretization) discretizeAttributes(discMod, algType, trainingSample);
  selectAttributes(selMode, algType, selectionList, selectionFile);

  if (verbosity->verbosityR.progress)
    cout << "\nlearning parameters ...";

  Classifier* classifier=getClassifier(algType, discMod, discretization, trainingSample);
  if (classifier==NULL)
  {
    zap(algorithmParameters);
    algorithmParameters=originalAlgorithmParameters;
    throw BadFormat("\nClassifierTest::Holdout. The classifier could not be built");
  }

  if (verbosity->verbosityR.progress)
    cout << "\nTesting phase ...";
  ClassificationResults* accuracy=classifier->getAccuracy(testSample, stochastic);

  zap(classifier);
  // Free the working copy and restore the caller's list. Without this, the
  // next call would copy from the zapped (NULL or dangling) pointer.
  zap(algorithmParameters);
  algorithmParameters=originalAlgorithmParameters;
  return accuracy;
}


/*____________________________________________________________________________________ */

// Training-set evaluation. In this mode the test sample is a copy of the
// training sample, so the holdout procedure is reused unchanged.
ClassificationResults* ClassifierTest::Training(AlgType algType, DiscMode discMod, floatList* discretization, SelMode selMode, int stochastic, floatList* selectionList, char* selectionFile)
{
  ClassificationResults* const results=Holdout(algType, discMod, discretization, selMode, stochastic, selectionList, selectionFile);
  return results;
}

/*____________________________________________________________________________________ */

// Leave-one-out: cross-validation with one fold per pattern. Note that this
// overwrites numberOfFolds with the size of wholeSample.
ClassificationResults* ClassifierTest::LeaveOneOut(AlgType algType, DiscMode discMod, floatList* discretization, SelMode selMode, int stochastic, floatList* selectionList, char* selectionFile)
{
  this->numberOfFolds=this->wholeSample->getSize();
  return this->CrossValidation(algType, discMod, discretization, selMode, stochastic, selectionList, selectionFile);
}
} // end namespace
#endif
