/* File: UndirectedBN.cpp */


#ifndef __UndirectedBN_cpp__
#define __UndirectedBN_cpp__


#include "UndirectedBN.h"





//using namespace UTILS;


namespace BIOS
{

 
/*
________________________________________________________________________________
____ */

  // Default constructor: produces an empty network; set() resets the pointer
  // members so the destructor has nothing to release.
  UndirectedBN::UndirectedBN()
  {
    this->set();
  }

 
/*
________________________________________________________________________________
__ */

  // Builds the junction tree for the network described by `parents` (one
  // parent list per attribute) and `probTables` (one CPT per attribute).
  // Pipeline: DAG (setDAG) -> moral graph -> triangulated graph -> cliques
  // -> spanning tree (getMWST) -> directed junction tree. It then assigns the
  // initial clique/separator potentials (setPotentials).
  // A NullValue raised while building the graphs is reported, not rethrown.
  UndirectedBN::UndirectedBN(intMLSample* sample, float alpha, intList** parents, CPT** probTables)
  {
    typedef UndirectedArc<Node> UArc;
    typedef DirectedArc<Node> DArc;
    if (parents==NULL || sample==NULL || probTables==NULL)
    {
      cout <<"Error in UndirectedBN:set";
      end();
    }
    set();
    this->sample=sample;
    this->parents=parents;
    this->alpha=alpha;
    this->probTables=probTables;

    totalAttributes=sample->listOfAttributes->GetTotalAttributes();
    totalSelectedAttributes=sample->listOfAttributes->getTotalSelectedAttributes();
    setNodeList();

    try
    {
      DAG<DArc, Node> *dAG=setDAG();
      // uarc is only handed to getMoralGraph; initialise it so no
      // indeterminate pointer value is ever read.
      UArc* uarc=NULL;
      UG<UArc, Node>* uG=dAG->getMoralGraph(uarc);
      zap(dAG);
      TriangulatedUG* triangulatedUG=new TriangulatedUG(*uG);
      zap(uG);

      cliques=new Set<Clique>(*triangulatedUG->cliques);
      zap(triangulatedUG);

      CompleteUG<Separator, Clique>* completeUG=new CompleteUG<Separator, Clique>(cliques);
      UTree<Separator, Clique>* uTree=completeUG->getMWST();
      zap(completeUG);

      // Same reason as uarc: s is only passed through to getDirectedTree.
      Separator* s=NULL;
      junctionTree=uTree->getDirectedTree(uTree->nodes->getFirstElement(), s);
      zap(uTree);

      this->setPotentials();
    }
    catch (NullValue& NV) {NV.PrintMessage("in UndirectedBN::UndirectedBN");}
  };
 
  /*____________________________________________________________________________________ */

  // Returns the directed junction tree built by the constructor (NULL when
  // none was built). This object keeps ownership: the destructor releases it.
  Tree<Separator, Clique>* UndirectedBN::getJunctionTree()
  {
    Tree<Separator, Clique>* tree=this->junctionTree;
    return tree;
  };
/*____________________________________________________________________________________ */

  void UndirectedBN::set()
  {
    parents=NULL;
    junctionTree=NULL;
    nodeList=NULL;
    cliques=NULL;
commonConditionals=NULL;
  };
 
/*
________________________________________________________________________________
____ */

  // Creates nodeList with one Node per selected attribute, identified by the
  // attribute's index. Unselected attributes get no node.
  void UndirectedBN::setNodeList()
  {
    nodeList=new Set<Node>();
    for (int att=0; att<totalAttributes; att++)
    {
      if (!sample->listOfAttributes->getElement(att)->isSelected())
        continue;
      nodeList->insertElement(new Node(att));
    }
  };

 
/*
________________________________________________________________________________
____ */

  // Releases the structures this object allocated.
  // sample, parents and probTables are not freed here; presumably they are
  // owned by the caller.
  // NOTE(review): setDAG() also stores nodeList in the temporary DAG's
  // `nodes`. This is only safe if the DAG destructor does not free its node
  // set — confirm.
  UndirectedBN::~UndirectedBN()
  {
    // commonConditionals holds one heap-allocated intList per junction-tree
    // clique, allocated in setPotentials. Free the elements while
    // junctionTree is still alive, because its clique count is the array
    // length.
    if (commonConditionals!=NULL && junctionTree!=NULL)
      for (int i=0;i<junctionTree->nodes->size();i++)
        zap(commonConditionals[i]);
    zap(junctionTree);
    zap(nodeList);
    zap(cliques);
    zaparr(commonConditionals);
  };

 
/*
________________________________________________________________________________
____ */

  // Returns a new potential over varList with every entry set to 1, built
  // with this network's alpha. The caller owns the returned table.
  PotentialTable* UndirectedBN::createPotentialSeparator (intList* varList)
  {
    intList* dims=sample->listOfAttributes->getDimensionList(varList);
    PotentialTable* ones=new PotentialTable(varList, dims, alpha);
    zap(dims);
    ones->initialize(1);
    return ones;
  }
 
/*____________________________________________________________________________________ */

  // Builds the initial potential of the clique whose variables are varList.
  //
  // An attribute's CPT is multiplied into the product when all of these hold:
  // - the attribute is in varList;
  // - its entry in remainingParents is still non-NULL;
  // - varList includes that parent list.
  // The entry is then zapped, so later cliques skip that CPT.
  //
  // If the product covers fewer variables than varList, it is extended by a
  // table of ones over the missing variables. If no CPT was placed, the
  // result is a table of ones over varList. The caller owns the returned
  // table. cliqueNumber is currently unused.
  PotentialTable* UndirectedBN::createPotentialClique (intList* varList, intList** remainingParents, int cliqueNumber)
  {
    intList* dims=sample->listOfAttributes->getDimensionList(varList);
    PotentialTable* product=NULL;
    PotentialTable* factor;
    PotentialTable* combined;

    int att=0;
    do
    {
      bool placeable=varList->findElement(att)!=NULL
        && remainingParents[att]!=NULL
        && varList->includes(remainingParents[att]);
      if (placeable)
      {
        factor=probTables[att]->convertToPotential();
        if (product==NULL)
          combined=new PotentialTable(*factor);
        else
          combined=*product*factor;
        zap(product);
        zap(factor);
        product=combined;
        // Mark this CPT as consumed.
        zap(remainingParents[att]);
      }
      att++;
    }
    while (att<totalAttributes);

    // Extend the product to the clique's full variable set.
    if (product!=NULL && product->varList->size()!=varList->size())
    {
      intList* missing=varList->copyElementsNotIn(product->varList);
      intList* missingDims=sample->listOfAttributes->getDimensionList(missing);
      factor=new PotentialTable(missing, missingDims, product->alpha);
      factor->initialize(1);
      combined=*product*factor;
      zap(product);
      product=combined;
      zap(factor);
      zap(missing);
      zap(missingDims);
    }

    if (product==NULL)
    {
      product=new PotentialTable(varList, dims, alpha);
      product->initialize(1);
    }
    zap(dims);
    return product;
  };
 
/*
________________________________________________________________________________
____ */

  // Creates the initial potential for a set of nodes.
  // - Separators get a table of ones (createPotentialSeparator).
  // - Cliques get the product of their assigned CPTs (createPotentialClique).
  // A NULL nodeSet is reported through the file's end() error path.
  // The caller owns the returned table.
  PotentialTable* UndirectedBN::createPotential (Set<Node>* nodeSet, bool isSeparator, intList** remainingParents, int cliqueNumber)
  {
    if (nodeSet==NULL)
    {
      cout <<"Error in UndirectedBN::createPotential, null set of common nodes";
      end();
    }
    intList* varList=nodeSet->Container<Node, ListOfPointers>::getIntListFromPointerList();
    PotentialTable* result=isSeparator
      ? createPotentialSeparator(varList)
      : createPotentialClique(varList, remainingParents, cliqueNumber);
    zap(varList);
    return result;
  };
 
/*
________________________________________________________________________________
____ */

  void UndirectedBN::setPotentials()
  {
   commonConditionals=new intList*[junctionTree->nodes->size()];
 for (int i=0;i<junctionTree->nodes->size();i++)
         commonConditionals[i]=new intList(); 

   intList** parents2=new intList*[totalAttributes];
   for (int i=0;i<totalAttributes;i++)
        if (sample->listOfAttributes->getElement(i)->isSelected())
         parents2[i]=new intList(*parents[i]);
 
    Set<Clique>::iterator pC=junctionTree->nodes->getFirst();
    Clique* clique;
    PotentialTable* potentialTable;
    int i=0;
    while (pC!=NULL)
    {
      clique=junctionTree->nodes->getElement(pC);
      potentialTable=createPotential((Set<Node>*)clique, false, parents2, i);
      clique->setPotential(potentialTable);
      zap(potentialTable);
      pC=junctionTree->nodes->getNext(pC);
      i++;
    }
zaparr(parents2);

    Tree<Separator, Clique>::iterator p=junctionTree->getFirst();
    Separator * s;

    while (p!=NULL)
    {
      s=junctionTree->getElement(p);
      if (s->commonNodes==NULL) {cout << "Error in UndirectedBN::setPotentials";end();}
      potentialTable=createPotential(s->commonNodes, true);
      s->setPotential(potentialTable);
      zap(potentialTable);
      p=junctionTree->getNext(p);
    }
//cout << *junctionTree->nodes;
//end();
    //junctionTree->nodes->modularPrint();
    // junctionTree->modularPrint();
};

 
/*
________________________________________________________________________________
____ */

  // Builds a new DAG over nodeList, with an arc parent -> child for every
  // entry of parents[i] of each selected attribute i.
  // The DAG's node set IS nodeList (shared, not copied). The caller owns the
  // returned DAG.
  // Throws NullValue if nodeList is NULL or an attribute/parent has no node.
  // In that case the partially built DAG is released the same way a caller
  // releases a finished one (zap).
  DAG<DirectedArc<Node>, Node>* UndirectedBN::setDAG()
  {
    // Check before allocating anything, so the early throw cannot leak.
    if (nodeList==NULL) throw NullValue();

    DAG<DirectedArc<Node>, Node>* dAG=new DAG<DirectedArc<Node>, Node>();
    dAG->nodes=nodeList;

    for (int i=0;i<totalAttributes;i++)
    {
      if (!sample->listOfAttributes->getElement(i)->isSelected() || parents[i]==NULL)
        continue;

      // Temporary key used only for the lookup.
      Node* node=new Node(i);
      Set<Node>::iterator pN=dAG->nodes->findElement(node);
      zap(node);
      if (pN==NULL) {zap(dAG); throw NullValue();}

      for (intList::iterator p=parents[i]->getFirst(); p!=NULL; p=parents[i]->getNext(p))
      {
        node=new Node(parents[i]->getElement(p));
        Set<Node>::iterator pN2=dAG->nodes->findElement(node);
        zap(node);
        if (pN2==NULL) {zap(dAG); throw NullValue();}

        DirectedArc<Node>* arc=new DirectedArc<Node>(0, pN2, pN);
        // The arc is freed right after insertion, so insertElement
        // presumably stores a copy.
        dAG->insertElement(arc);
        zap(arc);
      }
    }
    return dAG;
  }
 
/*
________________________________________________________________________________
____ */

  void UndirectedBN::removeInconsistenciesWithEvidence(int att, intList*
inputPattern, Tree<Separator,Clique>* junctionTree2)
  {
intList* l=new intList(); l->insertElement(24);
PotentialTable*e;
    intList* inputPattern2=NULL;
    PotentialTable* potential, *potential2, *pt2;
    Clique* clique;
    Set<Clique>::iterator pC=junctionTree2->nodes->getFirst();
int i=0;
    while (pC!=NULL)
    {
       clique=junctionTree2->nodes->getElement(pC);
//cout <<"c:" << *clique;
       potential=clique->getPotential();
       inputPattern2=inputPattern->copyElementsWithPositionsIn(potential->varList, false);

///////very important to compute probabilities with missing data in those cliques which do not include all the cpt they could because they are already included in other one, and then cliques are not marginals any more.
/*
if (commonConditionals[i]->size()>0)
for (int j=0;j<commonConditionals[i]->size();j++)
{
  pt2=probTables[commonConditionals[i]->getElement(j)]->convertToPotential();

  potential2=*potential*pt2;

}
else 
*/
//potential2=new PotentialTable(*potential);
//zap(potential);
potential->removeInconsistenciesWithEvidence(inputPattern2, 0);
/*
if (commonConditionals[i]->size()>0)
for (int j=0;j<commonConditionals[i]->size();j++)
{
  potential=*potential2/pt2;
}
else 
*/
//potential=new PotentialTable(*potential2);
//clique->setPotential(potential);
//zap(potential);
//      potential=clique->getPotential();
if (0==1)
 if (potential->varList->getFirstElement()==0 && potential->varList->size()==3)
//end();
 // if (potential->varList->getFirstElement()==22)
{
cout <<"\ninconss:" << *potential;
e=potential->marginalize(l);
cout <<"\naftermargin:" << *e;
e->normalize();
cout <<"\nAfternotm:" << *e;
}
       zap(inputPattern2);
       pC=junctionTree2->nodes->getNext(pC);
i++;
    }

 }
/*____________________________________________________________________________________ */
// Passes one message from cliqueSource to cliqueTarget across `separator`:
// - newSep    = marginal of the source potential over the separator's nodes
// - target   := target * newSep / oldSep
// - separator := newSep
// junctionTree2 is not used here.
// NOTE(review): the old separator/target tables obtained with getPotential()
// are freed after setPotential(). This assumes setPotential stores a copy and
// does not free the previous table — confirm.
void UndirectedBN::finalUpdate(Tree<Separator,Clique>* junctionTree2, Separator* separator, Clique* cliqueSource, Clique* cliqueTarget)
{
  PotentialTable* separatorPotential=separator->getPotential();
  intList* commonVars=separator->commonNodes->getIntListFromPointerList();
  PotentialTable* cliquePotential=cliqueTarget->getPotential();

  PotentialTable* newSeparatorPotential=cliqueSource->getPotential()->marginalize(commonVars);
  PotentialTable* newCliquePotential=*cliquePotential*newSeparatorPotential;
  PotentialTable* new2CliquePotential=*newCliquePotential/separatorPotential;

  separator->setPotential(newSeparatorPotential);
  cliqueTarget->setPotential(new2CliquePotential);

  // Each table is released exactly once.
  zap(commonVars);
  zap(separatorPotential);
  zap(newSeparatorPotential);
  zap(cliquePotential);
  zap(newCliquePotential);
  zap(new2CliquePotential);
}
/*____________________________________________________________________________________ */

// Recursively propagates messages through junctionTree2, starting at the
// clique `pointer`. For every separator whose first clique is `pointer`,
// with child = that separator's second clique:
// - collect == true:  recurse into child first, then pass child -> pointer;
// - collect == false: pass pointer -> child, then recurse into child.
// Every separator is scanned at each level, so the cost is
// O(cliques * separators).
void UndirectedBN::updateEvidence(Tree<Separator,Clique>* junctionTree2, Clique* pointer, bool collect)
{
  Tree<Separator, Clique>::iterator p=junctionTree2->getFirst();
  while (p!=NULL)
  {
    Separator* separator=junctionTree2->getElement(p);
    if (junctionTree2->getFirst(separator)==pointer)
    {
      Clique* child=junctionTree2->getSecond(separator);
      if (collect)
      {
        updateEvidence(junctionTree2, child, true);
        finalUpdate(junctionTree2, separator, child, pointer);
      }
      else
      {
        finalUpdate(junctionTree2, separator, pointer, child);
        updateEvidence(junctionTree2, child, false);
      }
    }
    // Advance through the same tree that produced the iterator.
    p=junctionTree2->getNext(p);
  }
}

/*
____________________________________________________________________________________ */

double* UndirectedBN::getPosteriorProb(int att, intList* inputPattern)
{

//cout <<"\n" << inputPattern->print()<< "\t" <<"Att:" << att;
  Node* node=new Node(att);
  Tree<Separator,Clique>* junctionTree2=new Tree<Separator,Clique>();
  Set<Clique>* currentCliqueList=new Set<Clique>(*cliques);
  junctionTree2->nodes=currentCliqueList;
  junctionTree2->copyArcs(junctionTree);
  junctionTree2->setRoot(junctionTree2->nodes->getFirstElement());
  
  

  
  removeInconsistenciesWithEvidence(att, inputPattern, junctionTree2);

//cout <<"BEFORE:"; 
//cout <<"\njunctionTree2firstnode:" << *junctionTree2->nodes->getFirstElement();

 /*
cout <<"\njunctionTree2:" << junctionTree2->getFirstElement()->print();
cout <<"\nfirstnode:" << junctionTree2->getFirstElement()->getFirst()->element->print();
cout <<"\nsecondnode:" << junctionTree2->getFirstElement()->getSecond()->element->print();
//end();
*/
  
  updateEvidence(junctionTree2, junctionTree2->getRoot(), true); //collect

 //   cout <<"FFFFFFF";
//cout <<*junctionTree2->nodes->getFirstElement();
//end();

  /*
cout <<"and then:";
cout <<"\njunctionTree2:" << junctionTree2->getFirstElement()->print();
cout <<"\nfirstnode:" << junctionTree2->getFirstElement()->getFirst()->element->print();
cout <<"\nsecondnode:" << junctionTree2->getFirstElement()->getSecond()->element->print();
*/
//end();
  updateEvidence(junctionTree2, junctionTree2->getRoot(), false);//distribute
//end();

  Clique* clique=NULL;
 
clique=junctionTree2->nodes->getElement(junctionTree2->nodes->findFinalElement(node));


  zap(node);
  PotentialTable* cP=NULL, *conditionalProbability=clique->getPotential();
 
  
  conditionalProbability->setTotalSample(sample->sample->size());
  //cout << *conditionalProbability;
 // end();
   // cout <<"\nFinal clique:" << clique->print();
  //end();
  
  
  
   //if (conditionalProbability==NULL) cout <<"FFF";
  //cout << conditionalProbability->print();
  //end();

  intList* l=new intList();
  l->insertElement(att);


  cP=conditionalProbability->marginalize(l);
  //cP->setTotalSample(sample->sample->size());
cP->normalize();
  // cout <<"fin:" << *cP;
  //end();
  // end();
  zap(l);

  int s=cP->dimensionList->getFirstElement();
  if (s!=cP->getSize())
  {
    cout <<"error in UndirectedTree:getPosteriorProb";
    end();
  }
  double *result=new double[s];
//  cout << *cP;
//end();

  //cout <<"\nBFEOFRE:";

//cout <<"\ncp" << *cP;
  for (int i=0;i<s;i++)
  //  result[i]=cP->getProbability(&i).convert();
{
//result[i]=cP->getValue(&i);
if (cP->getProbability(&i).isUndefined())
result[i]=1;//1;
//cout <<"\t " << cP->getProbability(&i).print();
else 
{
//cout <<"\t " << cP->getValue(&i);
result[i]=cP->getProbability(&i).convert();
}
//cout <<result[i] <<"--";
}

 // normalize(result, s);
  
 //  cout <<"AFTER:";

 //  for (int i=0;i<s;i++)
  //  result[i]=cP->getProbability(&i).convert();
//{
//cout <<result[i] <<"--";
//}
//  end();

// for (int i=0;i<s;i++)
//    cout << "\n" << print(result[i]);
//end();
  zap(junctionTree2);
  zap(cP);
  zap(currentCliqueList);
  //cout <<"dd";
  return result;
}
/*___________________________________________________________ */




}
;  // Fin del Namespace

#endif

/* Fin Fichero: UndirectedBN.cpp */
