/* File: CPT.cpp */


#ifndef __CPT_cpp__
#define __CPT_cpp__


#include "CPT.h"

using namespace std;

namespace BIOS
{



  /*______________________________________________________*/

  CPT::CPT()
  {
    // Empty CPT with no tables attached. Every pointer member is set explicitly:
    // the destructor zaps marginals/conditionals, and the delegating accessors
    // (getSize, getValue, ...) dereference distr. Without this, those members
    // would hold indeterminate values.
    marginals=NULL;
    conditionals=NULL;
    distr=NULL;
    prior=false;
  };
  /*______________________________________________________*/
/*
  CPT* CPT::operator/(CPT* source)
  {
  return this->product(source, false);

  };
  /*_______________________________________________________________*/
/*
  ProbabilityTable* CPT::operator*(ProbabilityTable* source)
  {
    // product is false when division
 return d=*this->conditional*source;
  };
  /*______________________________________________________*/

  CPT::CPT(CPT &source)
  {
    // Deep copy: the new CPT owns its own copies of the source tables.
    prior=source.prior;
    // Either pointer may legitimately be NULL. conditionals is NULL for a CPT
    // without conditioning variables, and both are NULL after the default ctor.
    // NOTE(review): tables are copied as ProbabilityTable; if a prior CPT holds
    // PriorTable objects this slices them - confirm against PriorTable's design.
    if (source.marginals!=NULL)
      marginals=new ProbabilityTable(*source.marginals);
    else
      marginals=NULL;
    if (source.conditionals!=NULL)
      conditionals=new ProbabilityTable(*source.conditionals);
    else
      conditionals=NULL;
    // Point distr at this object's own active table, as the other constructors do.
    distr=getProbabilityTable();
  };

   /*______________________________________________________*/

  CPT::CPT(intList *conditionalVarList, PriorTable *priorTable)
  {
    // CPT built from a prior table rather than from sample data.
    prior=true;
    if (empty(conditionalVarList))
    {
      // Nothing to condition on: the CPT is simply a private copy of the prior table.
      conditionals=NULL;
      marginals=new PriorTable(*priorTable);
    }
    else
    {
      // Split the prior into the marginal over the conditioning variables and the
      // conditional table (computed by PriorTable::marginalize / getConditional).
      marginals=priorTable->marginalize(conditionalVarList);
      conditionals=priorTable->getConditional(conditionalVarList);
    }
    // distr caches whichever table the delegating accessors read from.
    distr=getProbabilityTable();
  }
  /*______________________________________________________*/

  CPT::CPT(intMLSample*  sample, intList *varList, intList *conditionalVarList, intList* dimensionList, BayesType bayesType, float alpha)
  {
    // CPT estimated from sample data (not from a prior table).
    prior=false;

    // Denominator of the Bayesian correction for these variables, as supplied by
    // the sample's attribute list for the requested bayesType / alpha.
    float alphaDenominator=sample->listOfAttributes->getAlphaDenominator(bayesType, alpha, varList, conditionalVarList);

    // Table over varList built from the sample.
    ProbabilityTable* pT=new ProbabilityTable(sample, varList, dimensionList, NULL, alphaDenominator);

    if (!empty(conditionalVarList))
    {
      // pT is only scratch here: derive the marginal over the conditioning variables
      // and the conditional table from it, then release it.
      this->marginals=pT->marginalize(conditionalVarList);
      this->conditionals=pT->getConditional(conditionalVarList);
      zap(pT);
    }
    else
    {
      // No conditioning variables: take ownership of pT directly as the marginal
      // table. The destructor zaps it, so a deep copy followed by deleting the
      // original would be wasted work.
      this->marginals=pT;
      this->conditionals=NULL;
    }

    // distr caches whichever table the delegating accessors read from.
    distr=getProbabilityTable();
  };
 
  /*______________________________________________________*/

CPT::~CPT()
{
  // Release both owned tables. distr is intentionally not freed: it only aliases
  // one of these two pointers.
  zap(conditionals);
  zap(marginals);
};


  /*_______________________________________________________________*/


  ProbabilityTable* CPT::marginalize()
  {
    // Without a conditional table, the marginal table is the whole distribution:
    // hand back a fresh copy of it.
    if (conditionals==NULL)
      return new ProbabilityTable(*marginals);
    // Otherwise return the product of the marginals with the conditionals
    // (ProbabilityTable::operator*). NOTE(review): caller presumably owns the result.
    return *marginals*conditionals;
  }
/*_______________________________________________________________*/


  ProbabilityTable* CPT::getMarginals()
  {
    // Non-owning access to the marginal table; the CPT keeps ownership.
    return marginals;
  }
/*_______________________________________________________________*/


  PriorTable* CPT::getPriorConditionals()
  {
    // View the conditional table as a PriorTable (unchecked downcast).
    // NOTE(review): only meaningful when the stored table really is a PriorTable,
    // presumably for CPTs built from a prior - confirm at call sites.
    return static_cast<PriorTable*>(conditionals);
  }
/*_______________________________________________________________*/


  PriorTable* CPT::getPriorMarginals()
  {
    // View the marginal table as a PriorTable (unchecked downcast).
    // NOTE(review): only meaningful when the stored table really is a PriorTable.
    return static_cast<PriorTable*>(marginals);
  }
/*_______________________________________________________________*/


  ProbabilityTable* CPT::getConditionals()
  {
    // Non-owning access to the conditional table; NULL when the CPT has no
    // conditioning variables.
    return conditionals;
  }
/*_______________________________________________________________*/


  PotentialTable* CPT::convertToPotential()
  {
    // Build a newly allocated double-valued table over the same variables and
    // dimensions as the active probability table, and copy every cell into it.
    // Undefined probabilities are stored as 1.
    ProbabilityTable* table=getProbabilityTable();
    VarsTable<double>* cp=new VarsTable<double>(table->varList, table->dimensionList);
    try
    {
      const int size=this->getSize();
      for (int i=0;i<size;i++)
      {
        // Fetch each cell once: getValue(int) allocates and frees a position array
        // on every call, so querying it twice per cell doubled that work.
        Prob value=this->getValue(i);
        if (value.isUndefined()) cp->setValue(i, 1);
        else cp->setValue(i, value.convert());
      }
    }
    catch (MissingValue& mv){mv.PrintMessage("CPT::convertToPotential"); end();};

    return cp;
  }
 /*_______________________________________________________________*/


  void CPT::setMarginals(ProbabilityTable* marginals)
  {
    // Replace the marginal table pointer. The previous table is not freed here.
    // NOTE(review): the destructor zaps marginals, so the CPT ends up owning the new
    // pointer while the old one is left to the caller - confirm ownership contract.
    this->marginals=marginals;
    // distr caches getProbabilityTable(). Refresh it so getSize/getValue/... do not
    // keep reading a replaced (or never-initialised) table.
    distr=getProbabilityTable();
  }
/*_______________________________________________________________*/


  void CPT::setConditionals(ProbabilityTable* conditionals)
  {
    // Replace the conditional table pointer. The previous table is not freed here.
    // NOTE(review): the destructor zaps conditionals, so the CPT ends up owning the
    // new pointer - confirm ownership contract with callers.
    this->conditionals=conditionals;
    // The active table (and hence distr) depends on whether conditionals is NULL,
    // so the cached pointer must be recomputed after every change.
    distr=getProbabilityTable();
  }
/*_______________________________________________________________*/


  PriorTable* CPT::getPriorProbabilityTable()
  {
    // Same choice as getProbabilityTable() (conditionals when present, otherwise
    // marginals), viewed as a PriorTable via an unchecked downcast.
    ProbabilityTable* active=(conditionals!=NULL) ? conditionals : marginals;
    return static_cast<PriorTable*>(active);
  }
/*_______________________________________________________________*/


  ProbabilityTable* CPT::getProbabilityTable()
  {
    // The "active" table: the conditional table when one exists, otherwise the
    // marginal table. The constructors cache this pointer in distr.
    return (conditionals==NULL) ? marginals : conditionals;
  } /*_______________________________________________________________*/


  int* CPT::getPositions(int i)
  {
    // Decode flat cell index i into per-variable positions via the active table.
    // NOTE(review): the array comes from ProbabilityTable::getPositions; other code
    // in this file releases such arrays with zaparr - caller presumably owns it.
    int* positions=distr->getPositions(i);
    return positions;
  }
 /*_______________________________________________________________*/


  int CPT::getDimension()
  {
    // Forwarded unchanged to the active table (distr).
    const int dimension=distr->getDimension();
    return dimension;
  }
 /*_______________________________________________________________*/


  int CPT::getSize()
  {
    // Number of cells of the active table (distr); used as the loop bound when
    // iterating cells by flat index.
    const int cells=distr->getSize();
    return cells;
  }

 /*_______________________________________________________________*/


  int CPT::getTotalSample()
  {
 return distr->totalSample;
   }
 /*_______________________________________________________________*/


  int CPT::getPos(int* positions)
  {
    // Inverse of getPositions: encode per-variable positions into a flat cell
    // index of the active table (distr).
    const int index=distr->getPos(positions);
    return index;
  }
 /*_______________________________________________________________*/


  intList* CPT::getDimensionList()
  {
    // Non-owning pointer to the active table's list of per-variable dimensions.
    return this->distr->dimensionList;
  }
 /*_______________________________________________________________*/
/*
  Prob CPT::getProbability(int pos, float alphaNumerator, float alphaDenominator)
{
  int *posTable=getProbabilityTable()->getPositions(pos);
  Prob result=getProbability(posTable, alphaNumerator, alphaDenominator);
  zap(posTable);
  return result;
}
 /*_______________________________________________________________*/

  Prob CPT::getValue(int pos)
{
  // Value of the cell with flat index pos: decode it into per-variable positions,
  // then delegate to getValue(int*).
  int *posTable=getProbabilityTable()->getPositions(pos);
  Prob result=getValue(posTable);
  // getPositions hands back an array (one entry per variable); release it with the
  // array form, as operator<< in this file does, instead of scalar zap.
  zaparr(posTable);
  return result;
}
/*_______________________________________________________________*/

  Prob CPT::getValue(int* values)
  {
    // Value of the cell addressed by per-variable positions in the active table.
    try
    {
      return distr->getValue(values);
    }
    catch (MissingValue& mv)
    {
      mv.PrintMessage("CPT::getValue");
      // NOTE(review): end() presumably terminates the program - confirm.
      end();
      // If end() ever returns, propagate the original exception rather than
      // flowing off the end of a value-returning function (undefined behaviour).
      throw;
    }
  }


  /*_______________________________________________________________*/
/*
  Prob CPT::getProbability(int* values, float alphaNumerator, float alphaDenominator)
  {
Prob p=this->getValue(values);
return Prob(p.getNumerator(), p.getDenominator(), alphaNumerator, alphaDenominator);//indefined when true
}
 
 */
    

     template<class T>  ostream& operator<<(ostream& out, CPT& p)
  {
    // Row-by-row dump of a CPT: one line per cell of the conditional table.
    // NOTE(review): T appears in no parameter, so it cannot be deduced and plain
    // `out << cpt` never selects this template. Kept only for interface compatibility.
    if (p.conditionals==NULL)
      return out << p.marginals->print();

    for (int i=0; i<p.conditionals->getSize(); i++)
    {
      int* pos=p.conditionals->getPositions(i);
      out << "\nProb (";
      for (int j=0; j<p.conditionals->varList->GetSize(); j++)
        out << "var " << p.conditionals->varList->GetElement(j) << "= " << pos[j] << ", ";

      out << ") | Prob (";
      for (int j=0; j<p.marginals->varList->GetSize(); j++)
      {
        // NOTE(review): pv is looked up in marginals->varList but then indexes pos,
        // which follows conditionals->varList ordering - confirm this is intended.
        int pv=p.marginals->varList->GetPos(p.marginals->varList->findElement(p.marginals->varList->GetElement(j)));
        out << "var " << p.marginals->varList->GetElement(j) << ": " << pos[pv] << ", ";
      }

      out << ") = ";
      out << p.getValue(pos).print();
      zaparr(pos);
      // The original returned here, so only the first row was ever printed.
    }
    // Always return the stream, including when the conditional table has no cells.
    return out;
  };
/*______________________________________________________*/

    ostream& operator<<(ostream& out, CPT& p)
  {
    // Print a labelled dump of the CPT to `out`: the main table (conditional if
    // present, else marginal), followed by the marginals when both exist. Prior CPTs
    // are printed through their PriorTable views, sample CPTs as frequency tables.
    // Writes to the given stream (previously always to cout) and returns it for chaining.
    if (p.prior)
    {
      if (p.getConditionals()==NULL)
        out <<"\nPrior marginal probabilities:";
      else
        out <<"\nPrior conditional probabilities:";
      out << *p.getPriorProbabilityTable();
      if (p.getConditionals()!=NULL)
      {
        out <<"\nMarginal prior probabilities:";
        out << *p.getPriorMarginals();
      }
    }
    else
    {
      if (p.getConditionals()==NULL)
        out <<"\nMarginal sample frequencies:";
      else
        out <<"\nConditional sample frequencies:";
      out << *p.getProbabilityTable();
      if (p.getConditionals()!=NULL)
      {
        out <<"\nMarginal sample frequencies:";
        out << *p.getMarginals();
      }
    }
    // The original fell off the end of this non-void function (undefined behaviour).
    return out;
  }
}//end namespace

      /*______________________________________________________*/
/*
  ostream& operator<<(ostream& out, CPT& p)
  {
    out << p.print();
    return out;
  }
*/

#endif
