/* File: GBN.cpp */


#ifndef __GBN_cpp__
#define __GBN_cpp__






//using namespace UTILS;


namespace BIOS
{

	/*____________________________________________________________________________________ */

	/*
	 * Builds a GBN from a float sample. The network works on the discretised copy returned
	 * by sample->getDiscreteSample(); this object owns it and releases it in ~GBN.
	 *
	 * algorithm (must not be NULL) configures learning/inference:
	 *   element 0 == 1.0 -> direct method (single clique / enumeration) instead of the
	 *                       undirected network;
	 *   element 1        -> BayesType of the probability tables (MLE when absent);
	 *   element 2        -> alpha, only used when the BayesType is not MLE (1 when absent).
	 * alphaDefault is the alpha kept when the BayesType is MLE.
	 *
	 * A mismatch between the number of attributes and the size of the first sample pattern
	 * is reported through BadSize::PrintMessage, but construction continues.
	 *
	 * @throws NullValue if algorithm is NULL.
	 */
	GBN::GBN ( floatMLSample* sample, floatList* algorithm, VerbosityClass *verbosity, int alphaDefault )
	{
		init();
		this->discreteSample=sample->getDiscreteSample();
		if ( sample->listOfAttributes==NULL )
		{
			cout <<"error in GBN::GBN";
			end();
		}

		totalAttributos=sample->listOfAttributes->GetTotalAttributes();
		intList * inputPattern=NULL;
		this->verbosity=verbosity;
		alpha=alphaDefault;
		bayesType=MLE;
		this->algorithm=algorithm;
		if ( algorithm==NULL )
		{
			// A throwing constructor never runs ~GBN, so release the owned copy here.
			zap ( discreteSample );
			throw NullValue ( "GBN::GBN ( floatMLSample* sample, floatList* algorithm, VerbosityClass *verbosity, int alphaDefault )" );
		}
		if ( algorithm->size() >=1 && algorithm->getElement ( 0 ) ==1.0 ) directMethod=true;
		if ( algorithm->size() >=2 ) bayesType= ( BayesType ) ( algorithm->getElement ( 1 ) );
		if ( bayesType!=MLE ) alpha=1;
		if ( algorithm->size() >=3 && bayesType!=MLE ) alpha=algorithm->getElement ( 2 );
		try
		{
			inputPattern=discreteSample->sample->getFirstElement();
			if ( totalAttributos!=inputPattern->size() )
				throw BadSize();
		}
		catch ( BadSize& bs ) {bs.PrintMessage ( "GBN::GBN", totalAttributos, inputPattern->size() ); }
	}

	/*____________________________________________________________________________________ */

	/*
	 * Builds a GBN from an already discrete sample. The network works on a clone of the
	 * sample; this object owns the clone and releases it in ~GBN.
	 *
	 * algorithm (must not be NULL) configures learning/inference:
	 *   element 0 == 1.0 -> direct method (single clique / enumeration) instead of the
	 *                       undirected network;
	 *   element 1        -> BayesType of the probability tables (MLE when absent);
	 *   element 2        -> alpha, only used when the BayesType is not MLE (1 when absent).
	 * alphaDefault is the alpha kept when the BayesType is MLE.
	 *
	 * A mismatch between the number of attributes and the size of the first sample pattern
	 * is reported through BadSize::PrintMessage, but construction continues.
	 *
	 * @throws NullValue if algorithm is NULL (same contract as the floatMLSample overload).
	 */
	GBN::GBN ( intMLSample* sample, floatList* algorithm, VerbosityClass* verbosity, int alphaDefault )
	{
		init();
		this->discreteSample=sample->clone();
		if ( sample->listOfAttributes==NULL )
		{
			cout <<"error in GBN::GBN";
			end();
		}

		totalAttributos=sample->listOfAttributes->GetTotalAttributes();
		intList * inputPattern=NULL;
		this->verbosity=verbosity;
		alpha=alphaDefault;
		bayesType=MLE;
		this->algorithm=algorithm;
		if ( algorithm==NULL )
		{
			// Without this check algorithm->size() below dereferences NULL. A throwing
			// constructor never runs ~GBN, so release the owned clone here.
			zap ( discreteSample );
			throw NullValue ( "GBN::GBN ( intMLSample* sample, floatList* algorithm, VerbosityClass* verbosity, int alphaDefault )" );
		}
		if ( algorithm->size() >=1 && algorithm->getElement ( 0 ) ==1.0 ) directMethod=true;
		if ( algorithm->size() >=2 ) bayesType= ( BayesType ) ( algorithm->getElement ( 1 ) );
		if ( bayesType!=MLE ) alpha=1;
		if ( algorithm->size() >=3 && bayesType!=MLE ) alpha=algorithm->getElement ( 2 );

		try
		{
			inputPattern=discreteSample->sample->getFirstElement();

			if ( totalAttributos!=inputPattern->size() )
				throw BadSize();
		}
		catch ( BadSize& bs ) {bs.PrintMessage ( "GBN::GBN", totalAttributos, inputPattern->size() ); }
	}

	/*____________________________________________________________________________________ */

	/*
	 * Builds the undirected network used for inference when the direct method is not
	 * selected, from the discretised sample, alpha, the parent sets and the probability
	 * tables of this GBN.
	 *
	 * @param verbosity verbosity passed on to the UndirectedBN (shadows the member of the
	 *                  same name).
	 */
	void GBN::setUndirectedBN ( VerbosityClass *verbosity )
	{
		// Release a network built by an earlier call instead of leaking it.
		zap ( undirectedBN );
		undirectedBN=new UndirectedBN ( discreteSample, alpha, parents, probTables, verbosity );
	}
	/*____________________________________________________________________________________ */

	void GBN::init()
	{
		directMethod=false;
		parents=NULL;
		probTables=NULL;
		undirectedBN=NULL;
		clique=NULL;
		discreteSample=NULL;
	};
	/*____________________________________________________________________________________ */

	/*
	 * Finishes building the network once the parent sets are known: optionally prints
	 * the structure, computes the probability tables, then prepares the inference
	 * back-end (a single clique for the direct method, the undirected network otherwise).
	 */
	void GBN::set()
	{
		const bool showStructure= ( verbosity!=NULL && verbosity->verbosityR.structure );
		if ( showStructure )
			printParents();

		setProbTables();

		if ( !directMethod )
			setUndirectedBN ( verbosity );
		else
			createClique();
	}

	/*____________________________________________________________________________________ */

	/*
	 * Prints the junction tree of the undirected network when structure verbosity is on.
	 *
	 * undirectedBN is only built by setUndirectedBN(), which set() skips when directMethod
	 * is true (or when set() has not run yet); in that case there is nothing to print and
	 * the previous code dereferenced NULL.
	 */
	void GBN::report()
	{
		if ( verbosity==NULL || !verbosity->verbosityR.structure )
			return;
		if ( undirectedBN==NULL )
			return;
		cout << *undirectedBN->getJunctionTree();
	}
	/*____________________________________________________________________________________ */

	/* Default constructor: an empty network, every member left as init() sets it. */
	GBN::GBN()
	{
		init();
	}

	/*____________________________________________________________________________________ */

	/*
	 * Releases everything this object owns. The objects built from the shared data go
	 * first: undirectedBN was constructed from discreteSample, parents and probTables, and
	 * the clique holds potentials derived from probTables. The data itself goes last, so
	 * no destructor can observe already-freed inputs.
	 */
	GBN::~GBN()
	{
		zap ( undirectedBN );
		zap ( clique );
		zaparr ( parents );
		zaparr ( probTables );
		zap ( discreteSample );
	}
	/*____________________________________________________________________________________ */

	void GBN::printParents()
	{
		Attribute *attribute;
		intList::iterator p;
		int j;
		if ( parents!=NULL )
		{
			for ( int cont=0; cont < totalAttributos; cont++ ) //
// if (cont!=classPosition)
			{
				attribute=discreteSample->listOfAttributes->getElement ( cont );
				if ( attribute->isSelected() )
					cout <<"\nParents for attribute " << cont+1 << " (" << attribute->getName() << ") are: ";
				if ( !empty ( parents[cont] ) )
				{
					p=parents[cont]->getFirst();
					while ( p!=parents[cont]->end() )
					{
						j=parents[cont]->getElement ( p );
						attribute=discreteSample->listOfAttributes->getElement ( j );
						cout << j+1 << " (" << attribute->getName() <<")";
						p=parents[cont]->getNext ( p );
						if ( p!=parents[cont]->end() ) cout<<", "; else cout <<".";
					}
				}
//cout <<"\n";
			}
			cout <<"\n";
		}
		else
		{
			cout <<"Error in GBN::print()";
			exit ( 0 );
		}
	};

	/*____________________________________________________________________________________ */

	/* True once both the parent sets and the probability tables have been assigned. */
	bool GBN::isSet()
	{
		if ( parents==NULL )
			return false;
		return probTables!=NULL;
	}


	/*____________________________________________________________________________________ */

	/*
	 * Builds probTables[i], the table of attribute i conditioned on parents[i], for every
	 * attribute. Each table is built from the discretised sample with the configured
	 * bayesType and alpha. Requires the parent sets to be assigned.
	 *
	 * ORI errors raised while building a table get a hint appended (often a missing value)
	 * and are rethrown. On any exception the per-attribute temporaries are released, and
	 * the slots that were never built stay NULL.
	 */
	void GBN::setProbTables()
	{
		// Value-initialised: every slot starts as NULL, so a partially built array never
		// holds indeterminate pointers.
		probTables=new CPT*[totalAttributos]();

		if ( verbosity!=NULL && verbosity->verbosityR.progress ) cout <<"\ncomputing prob tables";

		for ( int i=0;i<totalAttributos;i++ )
		{
			if ( verbosity!=NULL && verbosity->verbosityR.progress && i%10==0 ) cout <<"\natt " << i+1 <<" of " << totalAttributos;

			// Table variables: the attribute first, then its parents (also the conditioning set).
			intList *dimensions=NULL, *conditionalVars=NULL;
			intList *vars=new intList();
			vars->insertElement ( i );
			if ( !empty ( parents[i] ) )
			{
				vars->copyPaste ( parents[i] );
				conditionalVars=new intList ( *parents[i] );
			}
			try
			{
				dimensions=discreteSample->listOfAttributes->getDimensionList ( vars );
				probTables[i]=new CPT ( discreteSample->sample, vars, conditionalVars, dimensions, bayesType, alpha, discreteSample->listOfAttributes );
			}
			catch ( ORI & ori )
			{
				ori.addMessage ( "GBN::setProbTables, it may be a missing value" );
				zap ( dimensions );
				zap ( conditionalVars );
				zap ( vars );
				throw;
			}
			catch ( ... )
			{
				zap ( dimensions );
				zap ( conditionalVars );
				zap ( vars );
				throw;
			}
			zap ( dimensions );
			zap ( conditionalVars );
			zap ( vars );
		} // end for each attribute
	}



	/*____________________________________________________________________________________ */

	double GBN::getAttributeProb ( int att, intList* inputPattern )
	{
//// select attributes in factor
		intList* selection=new intList();
		selection->insertElement ( att );
		if ( !empty ( parents[att] ) )
			selection->copyPaste ( parents[att] );


// filter pattern
		intList* reducedInputPattern=inputPattern->copyElementsWithPositionsIn ( selection, false ); //ordered as in argument
		int* positions=reducedInputPattern->getTable();
		Prob p=probTables[att]->getValue ( positions );//, alphaNumerator, alphaDenominator);
		zaparr ( positions );
		zap ( reducedInputPattern );
		zap ( selection );
		if ( p.isUndefined() )  return 1;
		else return p.convert();
	}
	/*____________________________________________________________________________________ */
//getClassicPosteriorProb
	/*
		double* GBN::getNewPosteriorProb ( int att, intList* inputPattern )
		{
	//it should not work for general BN, as that for the death risk calculator in SC disease
			if ( clique==NULL ) createClique();
			intList* atts=new intList();
			atts->insertElement ( att );
			Clique* clique2=new Clique ( *clique );
			clique2->getPotentialList()->removeInconsistenciesWithEvidence ( inputPattern );



			double* probs= clique2->getProbs ( atts );
			zap ( atts );
			zap ( clique2 );
			return probs;
		}
		/*____________________________________________________________________________________ */


	double* GBN::getClassicPosteriorProb ( int att, intList* inputPattern )
	{
		try
		{

			Attribute* attribute, *attribute2;
			attribute=discreteSample->listOfAttributes->getElement ( att );
			int totalModalidades=attribute->getTotalModalidades(), parent_atrib, i;
			double maxProb=0, totalLogProb=0, attProb; //, totalConfigurations=1;
			int** factorToAvoidUnderflow=new int*[totalModalidades];
			double* probs=new double[totalModalidades];
			double** p=new double*[totalModalidades];
			bool isZero, under;
			InitializeList ( probs, totalModalidades, 0.0 );
			int missingAtts=0, missingDims=1, *positions=NULL;
			MultidimensionalTable<int>* missingTable=NULL;
			intList* dims=new intList(), *index=new intList();
			// missing values

			for ( int num_atrib=0; num_atrib<totalAttributos; num_atrib++ )
				if ( discreteSample->listOfAttributes->getElement ( num_atrib )->isSelected() )
					if ( num_atrib!=att )
					{
						attribute2=discreteSample->listOfAttributes->getElement ( num_atrib );
						//  totalConfigurations=totalConfigurations* attribute2->getTotalModalidades();
						if ( attribute2->isMissing ( inputPattern->getElement ( num_atrib ) ) )
						{
							missingAtts++;
							missingDims=missingDims*attribute2->getTotalModalidades();
							dims->insertElement ( attribute2->getTotalModalidades() );
							index->insertElement ( num_atrib );
						}
					}


			for ( int val=0; val< totalModalidades; val++ )
			{
				factorToAvoidUnderflow[val]=Initialize ( missingDims, 0 );
				p[val]=Initialize ( missingDims, 1.0 );
			}

			if ( missingAtts>0 ) missingTable=new MultidimensionalTable<int> ( dims );
			for ( int val=0; val< totalModalidades; val++ )
			{
				inputPattern->changeElementAtPos ( val, att );
				for ( int s=0;s<missingDims;s++ )
				{
					if ( missingDims>1 )
					{
						positions=missingTable->getPositions ( s );
						for ( int d=0;d<missingAtts;d++ )
							inputPattern->changeElementAtPos ( positions[d], index->getElement ( ( int ) d ) );
						zaparr ( positions );
					}
					for ( int num_atrib=0; num_atrib<totalAttributos; num_atrib++ )
						if ( discreteSample->listOfAttributes->getElement ( num_atrib )->isSelected() )
						{
							isZero=false;
							attProb=getAttributeProb ( num_atrib, inputPattern );
							if ( p[val][s]==0 || attProb ==0.0 ) isZero=true;
//							if (num_atrib<(totalAttributos-1)) under=p[val][s]* attProb*getAttributeProb ( num_atrib+1, inputPattern ); else
							under= ( ( std::log10 ( p[val][s] ) +std::log10 ( attProb ) )  <= std::numeric_limits< double >::min_exponent10 );
							if ( !isZero )
								while ( under ) //underflow: throw Underflow("GBN::getClassicPosteriorProb ( int att, intList* inputPattern )");
								{
									p[val][s]=p[val][s]*10;
									factorToAvoidUnderflow[val][s]++;
									//									if (num_atrib<(totalAttributos-1)) under=p[val][s]* attProb*getAttributeProb ( num_atrib+1, inputPattern ); else
									under= ( ( std::log10 ( p[val][s] ) +std::log10 ( attProb ) )  <= std::numeric_limits< double >::min_exponent10 );
//probs[val]=probs[val]*discreteSample->listOfAttributes->getElement ( num_atrib )->getTotalModalidades();
								}
							p[val][s]=p[val][s]*attProb;

						}
				}
			} // for each modality
			zap ( missingTable );
			zap ( dims );
			zap ( index );

			int max=0;
			for ( int val=0; val< totalModalidades; val++ )
				for ( int s=0;s<missingDims;s++ )
				{
					if ( factorToAvoidUnderflow[val][s]>max ) max=factorToAvoidUnderflow[val][s];
				}

			for ( int val=0; val< totalModalidades; val++ )
				for ( int s=0;s<missingDims;s++ )
				{
					for ( int i=factorToAvoidUnderflow[val][s]; i<max; i++ )
					{


						p[val][s]=p[val][s]*10;
					}
					probs[val]=probs[val]+ p[val][s];
					if ( isinf ( probs[val] ) )
						throw Overflow ( "GBN::getClassicPosteriorProb ( int att, intList* inputPattern )" );
//				cout << "|" << probs[val] <<"|";
				}
			normalize ( probs, totalModalidades );
			for ( int i=0; i<totalModalidades;i++ )
			{
				zaparr ( factorToAvoidUnderflow[i] );
				zaparr ( p[i] );
			}

			zaparr ( factorToAvoidUnderflow,  totalModalidades );
			zaparr ( p, totalModalidades );
			return probs;
		}
		catch ( BasicException& be ) {be.addMessage ( "\ncalled from double* GBN::getClassicPosteriorProb(int att, intList* inputPattern)" ); throw;};
	}




	/*____________________________________________________________________________________ */

	/*
	 * Builds the single clique used by the direct method: one Node per attribute, and as
	 * its potential list the potential converted from every probability table.
	 * Requires setProbTables() to have run (probTables is dereferenced for every attribute).
	 */
	void GBN::createClique()
	{
		// Release a clique built by an earlier call instead of leaking it.
		zap ( clique );
		clique=new Clique();
		PotentialList* potentialList=new PotentialList();
		if ( verbosity!=NULL && verbosity->verbosityR.progress ) cout <<"\ncreating a unique clique for direct inference (it will not work with missing data)";
		for ( int num_atrib=0; num_atrib<totalAttributos; num_atrib++ )
		{
			if ( verbosity!=NULL && verbosity->verbosityR.progress ) cout <<"\ninserting  att " << num_atrib+1 <<" of " << totalAttributos  <<" in the clique";
			clique->insertElement ( new Node ( num_atrib ) );
			if ( verbosity!=NULL && verbosity->verbosityR.progress ) cout <<"\nadding potentials to clique for att " << num_atrib+1 <<" of " << totalAttributos;
			potentialList->insertElement ( probTables[num_atrib]->convertToPotential() );
		}
		clique->setPotentialList ( potentialList );
		// NOTE(review): presumably setPotentialList copies the list, since it is released here — confirm.
		zap ( potentialList );
	}

	/*____________________________________________________________________________________ */

	/*
	 * Structure learning is not implemented by GBN itself: this always throws
	 * NonImplemented. NOTE(review): presumably overridden by concrete networks — confirm.
	 */
	void GBN::setParents()
	{
		throw NonImplemented ( "GBN::setParents()" );
	}


	/*____________________________________________________________________________________ */

	/*
	 * Posterior distribution of attribute att given inputPattern. With directMethod it
	 * uses enumeration (getClassicPosteriorProb); otherwise it uses the undirected network
	 * (getJunctionTreePosteriorProb). inputPattern itself is not modified: the chosen
	 * method works on a copy.
	 *
	 * @return new[]-allocated array with one entry per value of att, normalised unless it
	 *         sums to zero. The caller owns it.
	 */
	double* GBN::getPosteriorProb ( int att, intList* inputPattern )
	{
		try
		{
			int totalModalidades=discreteSample->listOfAttributes->getElement ( att )->getTotalModalidades();
			// getClassicPosteriorProb overwrites entries of the pattern it is given.
			intList* inputCompletePattern=new intList ( *inputPattern );
			double* probs=NULL;
			try
			{
				if ( directMethod )
					probs=getClassicPosteriorProb ( att, inputCompletePattern );
				else
					probs=getJunctionTreePosteriorProb ( att, inputCompletePattern );
			}
			catch ( ... )
			{
				zap ( inputCompletePattern );
				throw;
			}
			zap ( inputCompletePattern );
			if ( sum ( probs, totalModalidades ) !=0 )
				normalize ( probs, totalModalidades );
			return probs;
		}
		catch ( BasicException& be ) {be.addMessage ( "\ncalled from double* GBN::getPosteriorProb(int att, intList* inputPattern))" ); throw;};
	}
	/*____________________________________________________________________________________ */

	double* GBN::getJunctionTreePosteriorProb ( int att, intList* inputPattern )
	{
		try
		{
//float alphaDenominator=getAlphaDenominator(att, inputPattern);
//float alphaNumerator=getAlphaNumerator(att, inputPattern, alphaDenominator);
			return undirectedBN->getPosteriorProb ( att, inputPattern );//, alphaNumerator, alphaDenominator);
		}
		catch ( BasicException& be ) {be.addMessage ( "\ncalled from double* GBN::getJunctionTreePosteriorProb ( int att, intList* inputPattern )" ); throw;};
	}
	/*____________________________________________________________________________________ */
	/*
	floatList* GBN::getAlphaNumerators(int att, intList* totalMarginal, float alphaDenominator)
	{
	floatList* result=new floatList();
	int totalModalidades=discreteSample->listOfAttributes->getElement(att)->GetTotalModalidades();
	double totalModalidades2;
	Prob alphaPortionNumerator;
	int totalSize=0;
	float total=0;

	if (parents[att]!=NULL && parents[att]->size()!=0)
	if (bayesType==MarginalBayes || bayesType==BDistanceMarginal)
	for (int a=0; a<totalModalidades;a++)
	totalSize=totalSize+totalMarginal->getElement(a);
	for (int a=0; a<totalModalidades;a++)
	{
	if (parents[att]==NULL || parents[att]->size()==0) alphaPortionNumerator=Prob(1, totalModalidades);
	else
	{
	totalModalidades2=discreteSample->listOfAttributes->getElement(parents[att]->getFirstElement())->GetTotalModalidades();
	if (bayesType==MarginalBayes || bayesType==BDistanceMarginal)
	{
	alphaPortionNumerator=Prob(totalMarginal->getElement(a), totalSize*totalModalidades2);
	//cout <<"att is " << att;
	//cout <<"\nden: " << alphaDenominator <<", portionNum: " << alphaPortionNumerator.print() <<"\n";
	//end();
	}
	else alphaPortionNumerator=Prob(1,totalModalidades*totalModalidades2);
	}
	result->insertElement(alphaPortionNumerator.convert()*alphaDenominator);
	total=total+alphaPortionNumerator.convert()*alphaDenominator;
	}
	/*
	if (total!=alphaDenominator)
	{
	cout <<"total is " << total <<" while totden is " << alphaDenominator <<"\n";
	}

	  }
	 /*___________________________________________________________________________________ */
	/*
	  float GBN::getAlphaDenominator(int att)
	  {
	    return alpha*getSNPDistance(att);
	  }
	  /*___________________________________________________________ */




}
;  // Fin del Namespace

#endif

/* Fin Fichero: GBN.cpp */
