/* File: AN.cpp */


#ifndef __AN_cpp__
#define __AN_cpp__

//#include <iostream.h>
//#include <cassert>
//#include <fstream.h>
#include "AN.h"







//using namespace UTILS;


namespace BIOS
{



  /************************/
  /* AN DEFINITION */
  /************************/


  /**
          @memo AN 
   
  	@doc
   
      @author M. Mar Abad Grau
  	@version 1.0
  */



  /*____________________________________________________________________________________ */

  /**
   * Destructor: releases, via zap(), the parental DAG built by
   * setParentalDAG(), the ordered attribute list, and the training sample
   * (the constructor stores a clone of the caller's sample there).
   */
  AN::~AN()
  {
    zap(this->parentalDAG);
    zap(this->orderedAttributes);
    zap(this->sample);
  }
  /*____________________________________________________________________________________ */

  /**
   * Walks orderedAttributes in order and, for the attribute at each
   * position i:
   *  - if it currently has exactly two parents, lets removeTANParent(i)
   *    drop the second (TAN) parent when that does not increase the error;
   *  - then greedily extends its parent set with addParents(i).
   *
   * Terminates the process with a non-zero status if `parents` or
   * `orderedAttributes` has not been set up.
   */
  void AN::setParents()
  {
    if (parents==NULL || orderedAttributes==NULL)
    {
      cout <<"ERROR in AN::SetParents()";
      exit(1); // non-zero: this is a failure, not a normal termination
    }
    int i=0; // position of iterator p inside orderedAttributes
    for (NodeSet::iterator p=orderedAttributes->getFirst(); p!=orderedAttributes->end(); p=orderedAttributes->getNext(p), i++)
    {
      int attribute=orderedAttributes->getElement(p)->getValue();
      if (parents[attribute]->size()==2)
        removeTANParent(i);
      // An earlier variant called addParents(i) only when removeTANParent(i)
      // returned false; parents are now always extended.
      addParents(i);
    }
  }
  /*____________________________________________________________________________________ */

  void AN::setParentalDAG()
  {

    parentalDAG=new SimpleDAG((SimpleDAG&)*parentalTree);

  }
  /*____________________________________________________________________________________ */

  bool AN::removeTANParent(int position)
  {
    int i=0, bestAttribute, attribute=orderedAttributes->getElement(position)->getValue();
    float predictedError=0, minError=maxreal;
    //boolean
    if (parents[attribute]->size()!=2)
    {
      cout <<"Error in AN::removeTANParent";
      end();
    }
    int tanParent=parents[attribute]->getElement((int)1);
    parents[attribute]->removeNode((int)1);
 //cout <<"\nERR: att" << attribute <<" with parent " << tanParent << ":" <<  getError(attribute, tanParent, true) <<" without tan:" << getError(attribute, -1, true);
 
    if (getError(attribute, tanParent, true)<getError(attribute, -1, true))
    {
        parents[attribute]->insertElement(tanParent);
      return false;
    }
    else return true;
  }
  /*____________________________________________________________________________________ */

  /**
   * Greedily adds parents to the attribute stored at `position` in
   * orderedAttributes.
   *
   * Up to `position` times, asks chooseOptimalAttribute() for the predecessor
   * that lowers the error. That call tightens minError whenever it returns an
   * attribute. Any returned attribute (>= 0) is appended to the parent list.
   * The search stops early when chooseOptimalAttribute() returns -2.
   *
   * If at least one attribute was added, the last one added is taken out and
   * the error of adding it back is evaluated without the class. When that
   * error does not exceed the best error found, the parent at index 0 is
   * removed. The last added attribute is then appended again.
   *
   * @param position index into orderedAttributes of the attribute to extend
   */
  void AN::addParents(int position)
  {
    int attribute=orderedAttributes->getElement(position)->getValue();
    float minError=maxreal;
    // Must be initialised: it is tested in the loop condition before the
    // first call to chooseOptimalAttribute(). Any value other than -2 works.
    int bestAttribute=-1;
    int lastBestAttribute=-1;
    int lastBestAttributePos=-1;

    for (int i=0; i<position && bestAttribute!=-2; i++)
    {
      bestAttribute=chooseOptimalAttribute(position, minError);
      if (bestAttribute>=0)
      {
        parents[attribute]->insertElement(bestAttribute);
        lastBestAttributePos=parents[attribute]->size()-1;
        lastBestAttribute=bestAttribute;
      }
    }

    if (lastBestAttribute>=0)
    {
      parents[attribute]->removeNode((int)lastBestAttributePos);
      // NOTE(review): index 0 presumably holds the class; it is dropped when
      // the model without the class is at least as good. Confirm.
      if (getError(attribute, lastBestAttribute, false)<=minError)
        parents[attribute]->removeNode((int)0);
      parents[attribute]->insertElement(lastBestAttribute);
    }
  }
  /*____________________________________________________________________________________ */

  float AN::getError(int attribute, int newParent, bool withClass)
  {
    float confidence=0.90;
    intList * varList=new intList(), *conditionalList=new intList();
    varList->insertElement(attribute);
    int element;
    double d=1;
    int level=parents[attribute]->size();
     
    int totalCombinations=1;
    Attribute *a;
    
 
    for (int i=0;i<parents[attribute]->size();i++)
    {
      element=parents[attribute]->getElement(i);
      if (element!=totalAttributes-1 || withClass==true)
      {
        a=listOfAttributes->getElement((int)element); 
        conditionalList->insertElement(element);
        //d=d*sample->listOfAttributes->getElement(element)->GetTotalModalidades();
	totalCombinations=totalCombinations*a->getTotalModalidades();
      }
    }
      
    
    
    if (newParent>-1)
    {
      a=listOfAttributes->getElement((int)newParent);        
      conditionalList->insertElement(newParent);
      totalCombinations=totalCombinations*a->getTotalModalidades();   
      }
  //d=d*sample->listOfAttributes->getElement(element)->GetTotalModalidades();
	
    if (totalTrainingSampleSize<=totalCombinations) 
    {
     zap(conditionalList);
     zap(varList);
     return maxreal;
    }
      
    if (conditionalList->size()==0)     
    {
  //  cout <<"non cond, only att ";
    zap(conditionalList);
    }
    //d=d*sample->listOfAttributes->getElement(newParent)->GetTotalModalidades();
    
    // if (conditionalList->size()==1)     
    //   cout <<"\ncond for att " << attribute << " is " << conditionalList->getFirstElement();
   

    float measure=sample->getMeasure(stopCriterion, varList, conditionalList);
 
    switch(stopCriterion)
    {
    case mMDL: cout <<"Error in AN::getAccuracy: Non implemented yet"; exit(0); break; 
    case mSRM:
    d=totalCombinations*log(sample->listOfAttributes->getElement(element)->getTotalModalidades());

      //  d=combinations(totalSelectedAttributes, level);
      d=(d-log(confidence))/(2*sample->getSize());
      measure= measure + (d/(float)2)*((float)1+sqrt((float)1+(float)4*measure/d)); break; 
    }
    

    return measure;
  };
  /*____________________________________________________________________________________ */

  /**
   * Among the attributes preceding `position` in orderedAttributes that are
   * neither the attribute itself nor already one of its parents, finds the
   * one whose addition yields the lowest getError().
   *
   * @param position      index into orderedAttributes of the target attribute
   * @param originalError in: current best error; out: lowered to the new
   *                      best error when an improving attribute is returned
   * @return the improving attribute (>= 0) if its error is below
   *         originalError; -1 if no candidate improves, but the best error is
   *         below 2*originalError (caller may keep searching); -2 otherwise
   *         (caller should stop searching)
   */
  int AN::chooseOptimalAttribute(int position, float &originalError)
  {
    int attribute=orderedAttributes->getElement(position)->getValue();
    int bestAttribute=-1;
    float minError=maxreal;

    // Evaluate every candidate before deciding. The original returned from
    // inside the loop, so only the first predecessor was ever considered.
    for (int i=0; i<position; i++)
    {
      int currentAttribute=orderedAttributes->getElement(i)->getValue();
      bool isCandidate=currentAttribute!=attribute
        && parents[attribute]->findElement(currentAttribute)==parents[attribute]->end();
      if (!isCandidate)
        continue;
      float predictedError=getError(attribute, currentAttribute);
      if (predictedError<minError)
      {
        minError=predictedError;
        bestAttribute=currentAttribute;
      }
    }

    if (minError<originalError)
    {
      originalError=minError;
      return bestAttribute;
    }
    if (minError<2*originalError)
      return -1; // not added, but algorithm continues
    return -2;   // finish node search
  }
/*____________________________________________________________________________________ */

/**
 * Prepares the classifier. When directMethod is false, the junction tree of
 * undirectedBN is built; otherwise nothing is done.
 */
void AN::set()
{
  if (directMethod)
    return;
  undirectedBN->setJunctionTree();
  // Previously disabled calls:
  //   NB<T>::setClassifier();
  //   parents=Initialize(ClassAttribute::GetModalidades(), NULL);
}


  
  /*____________________________________________________________________________________ */

  /**
   * Builds an AN classifier on top of TAN.
   *
   * Stores a clone of `sample` from which patterns with a missing value at
   * `classPosition` are removed, and records its size as the training sample
   * size. It then runs TAN::set() and leaves parentalDAG NULL.
   * The stop criterion is element 1 of `algorithm` (element 0 is the Bayes
   * prior), or mSRM when `algorithm` has fewer than two elements.
   */
  AN::AN(floatMLSample* sample, int classPosition, floatList* algorithm, VerbosityClass *verbosity, LossFunction* lossFunction):TAN(sample, classPosition, algorithm, verbosity, lossFunction)
  {
    this->sample=sample->clone();
    this->sample->removePatternsWithMissingAttribute(classPosition);
    totalTrainingSampleSize=this->sample->size();

    TAN::set();
    parentalDAG=NULL;

    stopCriterion=(algorithm->size()<2)
      ? mSRM
      : (MeasureType)(int)algorithm->getElement(1);
  }
/*____________________________________________________________________________________ */

  /**
   * Default constructor.
   *
   * The destructor zaps parentalDAG and sample, so both are given a defined
   * (NULL) value here instead of being left indeterminate. The stop criterion
   * and training-sample size get the same defaults the other constructor
   * would use for an empty configuration.
   */
  AN::AN():TAN()
  {
    parentalDAG=NULL;
    sample=NULL; // NOTE(review): confirm TAN() does not already own a sample here
    totalTrainingSampleSize=0;
    stopCriterion=mSRM;
  }



  /*___________________________________________________________ */

  /**
   * Writes a one-line description ("AN with alpha=<alpha, 2 decimals>") into
   * the buffer `line` and returns that buffer.
   */
  char* AN::print()
  {
    char* description=line;
    sprintf(description, "AN with alpha=%0.2f", alpha);
    return description;
  }

}
;  // Fin del Namespace

#endif

/* End of file: AN.cpp */
