#include <rumba/manifoldmatrix.h>


using RUMBA::Manifold;
using RUMBA::ManifoldMatrix;



// Wrap M as a matrix with an explicitly chosen orientation.
// Untransposed: one row per pixel, one column per timepoint.
// Transposed:   one row per timepoint, one column per pixel.
// The returned matrix is built from M and M.begin() (no element copying here).
ManifoldMatrix RUMBA::makeMatrix(Manifold<double> M, bool transpose )
{
	int rowCount;
	int colCount;
	if ( transpose )
	{
		rowCount = M.timepoints();
		colCount = M.pixels();
	}
	else
	{
		rowCount = M.pixels();
		colCount = M.timepoints();
	}
	return ManifoldMatrix ( M, M.begin(), rowCount, colCount, transpose );
}

// Wrap M as a matrix, inferring its orientation from the header:
// if the "hist_description" header entry equals "transposed matrix" the
// matrix is treated as transposed, otherwise as untransposed.
// The (rows, cols) computation is delegated to makeMatrix(M, transpose) so
// the orientation mapping is defined in exactly one place.
ManifoldMatrix RUMBA::makeMatrix(Manifold<double> M )
{
	// count() is checked first so operator[] never inserts a missing key.
	const bool transposed = M.headerData().count ( "hist_description" )
		&& M.headerData()["hist_description"].asString() == "transposed matrix";
	return makeMatrix ( M, transposed );
}

// Create a new i-row, j-column matrix by delegating to a throwaway
// manifold_generator; the ints are passed to its (r, c) overload unchanged.
ManifoldMatrix RUMBA::makeMatrix(int i, int j)
{
	return manifold_generator() ( i, j );
}

// Build an n x n identity matrix backed by an n-pixel, n-timepoint manifold:
// every element is cleared to zero, then element (k, 0, 0, k) -- pixel k at
// timepoint k, i.e. the diagonal -- is set to one.
ManifoldMatrix RUMBA::identityMatrix(int n)
{
	Manifold<double> I ( intPoint(n,1,1,n) );
	memset ( I.begin(), 0, I.size() * sizeof(double) );

	int k = 0;
	while ( k < n )
	{
		I.setElement ( intPoint(k,0,0,k), 1 );
		++k;
	}

	return makeMatrix(I);
}

// Left pseudo-inverse of M, computed as (M^T M)^-1 M^T.
// NOTE(review): presumably relies on M having full column rank so that
// M^T M is invertible -- confirm what invert() does otherwise.
ManifoldMatrix RUMBA::pseudoInvert( const ManifoldMatrix& M)
{
	ManifoldMatrix Mt = M.transpose();
	return invert ( Mt * M ) * Mt;
}

// Return true if the result matrix should be transposed. Here, left is the
// (x1,y1,z1,t1) extent of an untransposed ManifoldMatrix and right is the
// (x2,y2,z2,t2) extent of a transposed ManifoldMatrix. The product has
// x1 * y1 * z1 rows and x2 * y2 * z2 columns; the tricky part is deciding how
// to lay that shape out as a 4-d manifold. The rule: if the left operand has
// at most one spatial axis longer than 1 and the right operand has more than
// one, squash the left operand into the time axis (transposed result);
// otherwise squash the right operand (untransposed result).

namespace {

// Number of spatial axes (x, y, z) of p whose extent is greater than 1.
int nontrivialSpatialAxes ( RUMBA::intPoint p )
{
	return ( p.x() > 1 ) + ( p.y() > 1 ) + ( p.z() > 1 );
}

// True when right spans more than one non-trivial spatial axis while left
// spans at most one; in that case the caller keeps right's spatial layout
// and flattens left into the time axis, producing a transposed result.
bool tiebreaker (RUMBA::intPoint left, RUMBA::intPoint right)
{
	return nontrivialSpatialAxes ( right ) > 1
		&& nontrivialSpatialAxes ( left ) <= 1;
}

}


// the tricky part here is to work out whether the result should be transposed
// or not
ManifoldMatrix RUMBA::manifold_generator::operator() (const ManifoldMatrix& left, const ManifoldMatrix& right)
{

	if ( left.Transpose && ! right.Transpose )
	{
		return makeMatrix
			( 
			Manifold<double> 
			( 
				intPoint(left.M.timepoints(),1,1,right.M.timepoints()) 
			),
			false
			);
	}
	else if ( left.Transpose && right.Transpose )
	{
		return makeMatrix 
		(
			 Manifold<double>
			 (
			  intPoint(right.M.width(),right.M.height(),right.M.depth(),left.M.timepoints())
			  ),
			 true
		 );
	}
	else if ( ! left.Transpose && ! right.Transpose )
	{
		return makeMatrix
			(
			 Manifold<double>
			 (
			  intPoint(left.M.width(),left.M.height(),left.M.depth(),right.M.timepoints()
				  )
			  ),
			 false
			 );
	}
	else if ( !	tiebreaker ( left.M.extent(), right.M.extent() ) )
	{
		return makeMatrix 
			(
				Manifold<double> ( 
				intPoint
				( 
					left.M.width(), 
					left.M.height(), 
					left.M.depth(), 
					right.M.pixels()
				)
				),
				false
			);
	}
	else 
	{
		return makeMatrix 
			(
			 Manifold<double> (
			 intPoint
			 (
			  right.M.width(),
			  right.M.height(),
			  right.M.depth(),
			  left.M.pixels()
			 )
			 ),
			 true
			);
	}
	
}

// Create a new matrix with the same 4-d extent, row/column counts and
// transpose flag as my_matrix, backed by freshly constructed storage.
// Element values are whatever Manifold<double>'s constructor leaves there.
ManifoldMatrix RUMBA::manifold_generator::operator() (const ManifoldMatrix& my_matrix)
{
	Manifold<double> storage ( my_matrix.M.extent() );
	const bool flipped = my_matrix.Transpose;
	return ManifoldMatrix ( storage, storage.begin(), my_matrix.rows(), my_matrix.cols(), flipped );
}

// Create a new matrix backed by an r-pixel (along x), c-timepoint manifold;
// orientation is decided by makeMatrix(Manifold<double>) from the header.
ManifoldMatrix RUMBA::manifold_generator::operator() (unsigned int r, unsigned int c)
{
	Manifold<double> storage ( intPoint ( r, 1, 1, c ) );
	return makeMatrix ( storage );
}


// Copy row r of Mm into a new matrix backed by a 1-pixel, Mm.cols()-timepoint
// manifold (presumably a 1 x cols() untransposed matrix via makeMatrix).
// Throws RUMBA::Exception if r is not a valid row index, i.e. not in
// [0, Mm.rows()).
ManifoldMatrix RUMBA::getRow ( const ManifoldMatrix& Mm, int r )
{
	// Valid rows are 0 .. rows()-1: reject r == rows() and negative r too,
	// otherwise element(r, i) below would read out of range.
	if ( r < 0 || r >= Mm.rows() )
		throw RUMBA::Exception("in RUMBA::getRow(): argument out of range");

	const int ncols = Mm.cols();
	Manifold<double> M ( intPoint(1,1,1,ncols) );
	for ( int i = 0; i < ncols; ++i )
		M[ i ] = Mm.element(r,i);
	return makeMatrix(M);
}


		
	

#ifdef TEST_MATRIX
#include <iostream>

// Manual smoke test (build with -DTEST_MATRIX): builds two small matrices,
// multiplies them via multiply() and operator*, inverts the first one and
// prints the inverse. Standard-library and RUMBA names are qualified
// explicitly because this runs at global scope, where only Manifold and
// ManifoldMatrix are brought in by the using-declarations at the top.
int main()
{
	Manifold<double> M(RUMBA::intPoint(2,1,1,2));
	M[0] = 9;
	M[1] = -2;
	M[2] = -1;
	M[3] = 3;

	ManifoldMatrix mM (M,M.begin(),2,2,false);
	std::cout << "Got here" << std::endl;

	Manifold<double> N(RUMBA::intPoint(2,1,1,3));
	N[0] = 5;
	N[1] = 0;
	N[2] = 5;
	N[3] = 1;
	N[4] = 2;
	N[5] = 3;

	ManifoldMatrix mN (N,N.begin(),2,3,false);
	std::cout << "Got here" << std::endl;

	// call multiply (found by argument-dependent lookup on ManifoldMatrix)
	ManifoldMatrix mP = multiply(mM, mN);

	// or just use operator*
	mP = mM * mN;

//	RUMBA::LU_Functor<ManifoldMatrix> x(mM);
//	cout << x.determinant() << endl;
//	std::vector<double> v;
//	v.push_back(-13); v.push_back(14);
//	x.solve(v);
//	cout << "Det: " << x.determinant() << endl;
//	cout << "Solution: " << v[0] << " " << v[1] << endl;
//	cout << mM.element(0,0) << " " <<  mM.element(0,1) << endl;
//	cout << mM.element(1,0) << " " <<  mM.element(1,1) << endl;
	ManifoldMatrix mI = invert(mM);
	std::cout << mI.element(0,0) << " " <<  mI.element(0,1) << std::endl;
	std::cout << mI.element(1,0) << " " <<  mI.element(1,1) << std::endl;


	std::cout << "Got here: end" << std::endl;
	return 0;
}

#endif
