#include <rumba/matrix.hpp>
#include <rumba/manifoldmatrix.h>
#include <rumba/matrixio.h>
#include <rumba/point.hpp>
#include <iostream>

#define DEBUG



#include "svd.h"

#include "jacobi.h"
#include "givens.h"
#include "householder.h"

using namespace RUMBA;

namespace
{
// Worker shared by both RUMBA::destructive_svd overloads; its definition
// appears later in this file.  U is only assigned when accumulate_U is true.
void destructive_svd_impl(ManifoldMatrix& X, ManifoldMatrix& U,
                          ManifoldMatrix& V, ManifoldMatrix& D,
                          bool accumulate_U);
}


void RUMBA::svd
(
 const ManifoldMatrix& X, 
 ManifoldMatrix& U, 
 ManifoldMatrix& V, 
 ManifoldMatrix& D
 )
{
	ManifoldMatrix C (makeMatrix(X.M.copy()));
	destructive_svd(C,U,V,D);
}

void RUMBA::svd
(
 const ManifoldMatrix& X, 
 ManifoldMatrix& V, 
 ManifoldMatrix& D
 )
{
	ManifoldMatrix C (makeMatrix(X.M.copy()));
	destructive_svd(C,V,D);
}

// In-place SVD: X is overwritten during bidiagonalization.  On return U, V
// and D hold the factors of the original X = U D V^t.
void RUMBA::destructive_svd(ManifoldMatrix& X, ManifoldMatrix& U,
                            ManifoldMatrix& V, ManifoldMatrix& D)
{
	const bool want_left_vectors = true;
	destructive_svd_impl(X, U, V, D, want_left_vectors);
}

// In-place SVD that skips accumulating the left singular vectors; X is
// overwritten and only V and D are returned.
void RUMBA::destructive_svd(ManifoldMatrix& X, ManifoldMatrix& V,
                            ManifoldMatrix& D)
{
	// destructive_svd_impl needs a U to bind to, even though it leaves U
	// alone when accumulate_U is false.
	ManifoldMatrix ignored_U = makeMatrix(Manifold<double>());
	destructive_svd_impl(X, ignored_U, V, D, /* accumulate_U = */ false);
}




/*---------------------- Implementation Prototypes -----------------------*/

namespace 
{

// main intermediate step in svd algortihm. Diagonalizes a bidiagonal block 
// that contains no zeroes
void svd_step(ManifoldMatrix&);

// given bidiagonal matrix A, such that entry A(i,i) == 0, zero row i
// by pre-multiplying with a sequence of Givens transforms
void flattenRow
( 
 ManifoldMatrix& diagonal, 
 ManifoldMatrix& superdiagonal, 
 unsigned int i,
 std::list<Givens> * transforms 
 );

// given bidiagonal matrix A, such that entry A(i,i) == 0, zero column i
// by post-multiplying with a sequence of Givens transforms
void flattenColumn 
( 
 ManifoldMatrix& diagonal, 
 ManifoldMatrix& superdiagonal,
 unsigned int i,
 std::list<Givens>* transforms 
 );

// helper for threshold
bool svd_check_fabs
(
 int i, 
 ManifoldMatrix& diagonal,
 ManifoldMatrix& superdiagonal
);

// wipe out small entries 
void svd_threshold
( 
 ManifoldMatrix& diagonal, 
 ManifoldMatrix& superdiagonal
 );

// re-construct the right hand svd matrix V in the formula 
// M = UDV^t
ManifoldMatrix svd_accumulate_right
( 
 const ManifoldMatrix& hh, 
 const std::list<Givens>& mylist
 );


// re-construct the left hand svd matrix U in the formula 
// M = UDV^t.  Note that the matrix applied to A in the diagonalization
// process is U^t. So to recover U, we need to apply transforms in reverse
// order.
ManifoldMatrix svd_accumulate_left
( 
 const ManifoldMatrix& hh, 
 const std::list<Givens>& mylist
 );





// Zero the superdiagonal entry of row i of an upper-bidiagonal matrix whose
// diagonal entry (i,i) is zero, by pre-multiplying with Givens rotations.
// The matrix is passed as two column vectors: its diagonal (length n) and
// its superdiagonal (length n-1).  Each rotation is tagged with coordinates
// (i, j) and appended to *transforms.
// Precondition (not checked here): i < n-1, since superdiagonal(i) is read.
void flattenRow
( 
 ManifoldMatrix& diagonal, 
 ManifoldMatrix& superdiagonal, 
 unsigned int i,
 std::list<Givens> * transforms 
 )
{
	const unsigned int size = diagonal.rows();

	// The entry being chased rightward along row i.  It starts as the
	// superdiagonal element, which is removed from the matrix immediately.
	double bulge = superdiagonal.element(i,0);
	superdiagonal.element(i,0) = 0;

	unsigned int j = i + 1;
	while ( j < size )
	{
		Givens rot = get_givens2(bulge, diagonal.element(j,0));
		rot.setCoords(i,j);
		transforms->push_back(rot);

		diagonal.element(j,0) = rot.s() * bulge + rot.c() * diagonal.element(j,0);

		// Rotating rows i and j spills part of superdiagonal(j) back into
		// row i; that becomes the next bulge.  The bulge must be computed
		// before superdiagonal(j) is scaled.
		if ( j + 1 < size )
		{
			bulge = -rot.s() * superdiagonal.element(j,0);
			superdiagonal.element(j,0) *= rot.c();
		}

		++j;
	}
}

// Zero the superdiagonal entry in column i (that is, superdiagonal(i-1)) of
// an upper-bidiagonal matrix whose diagonal entry (i,i) is zero, by
// post-multiplying with Givens rotations tagged (j, i) for j = i-1 down
// to 0.  Same diagonal/superdiagonal vector representation as flattenRow.
// Precondition (not checked here): i >= 1, since superdiagonal(i-1) is read.
void flattenColumn 
( 
 ManifoldMatrix& diagonal, 
 ManifoldMatrix& superdiagonal,
 unsigned int i,
 std::list<Givens>* transforms 
 )
{
	// Entry being chased upward along column i.
	double bulge = superdiagonal.element(i-1,0);
	superdiagonal.element(i-1,0) = 0;

	int j = i-1;
	while ( j >= 0 )
	{
		Givens rot = get_givens(diagonal.element(j,0), bulge);
		rot.setCoords(j,i);
		transforms->push_back(rot);

		diagonal.element(j,0) = rot.c() * diagonal.element(j,0) - rot.s() * bulge;

		// The rotation of columns j and i moves part of superdiagonal(j-1)
		// into column i; compute it before scaling that entry.
		if ( j != 0 )
		{
			bulge = rot.s() * superdiagonal.element(j-1,0);
			superdiagonal.element(j-1,0) *= rot.c();
		}

		--j;
	}
}

// Returns true when superdiagonal(i) is negligible relative to its two
// diagonal neighbours:  |f_i| < 1e-8 * (|d_i| + |d_{i+1}|).
bool svd_check_fabs
(
 int i, 
 ManifoldMatrix& diagonal,
 ManifoldMatrix& superdiagonal
)
{
	const double relative_tolerance = 1e-8;
	const double neighbour_scale =
		fabs(diagonal.element(i,0)) + fabs(diagonal.element(i+1,0));

	return fabs(superdiagonal.element(i,0)) < relative_tolerance * neighbour_scale;
}

// Zero every superdiagonal entry that svd_check_fabs deems negligible, so
// the bidiagonal matrix splits into smaller independent blocks.
void svd_threshold(ManifoldMatrix& diagonal, ManifoldMatrix& superdiagonal)
{
	// Number of superdiagonal entries: one less than the diagonal length.
	const int entries = diagonal.rows() - 1;

	for ( int i = 0; i < entries; ++i )
	{
		if ( !svd_check_fabs(i, diagonal, superdiagonal) )
			continue;
		superdiagonal.element(i,0) = 0;
	}
}



// One implicit-shift QR sweep (Golub-Kahan SVD step) on an unreduced
// upper-bidiagonal block B.
//
//  diagonal, superdiagonal : B's diagonal (length n) and superdiagonal
//                            (length n-1) as column vectors; overwritten
//                            with the entries of B after the sweep.
//  transforms              : receives the column (right) rotations.
//  left_transforms         : receives the row (left) rotations.
//  offset                  : added to both coordinates of every recorded
//                            rotation, so they index into the full matrix
//                            rather than into this block.
//
// The shift mu is the eigenvalue of the trailing 2x2 block of T = B^t B that
// is closer to T's last diagonal entry.  Blocks with fewer than two rows
// have no superdiagonal and are returned unchanged.
void svd_step(ManifoldMatrix& diagonal, ManifoldMatrix& superdiagonal,
	std::list<Givens>* transforms, std::list<Givens>* left_transforms,
	int offset)
{
	int n = diagonal.rows();
	if ( n < 2 )
		return;

	// First column of T - mu*I (the shift is subtracted from y below).
	double y = diagonal.element(0,0) * diagonal.element(0,0);
	double z = diagonal.element(0,0)*superdiagonal.element(0,0);
	double mu;
	Givens g;

	// Trailing 2x2 block of T = B^t B.  T(n-2,n-2) includes the square of
	// the superdiagonal entry above row n-2.  A 2x2 block has no such entry,
	// so that term is zero; reading superdiagonal(n-3) unconditionally would
	// index -1.
	const double f_above = ( n > 2 ) ? superdiagonal.element(n-3,0) : 0.0;

	ManifoldMatrix T ( makeMatrix(Manifold<double>(intPoint(2,1,1,2))));

	T.element(0,0) = diagonal.element(n-2,0)*diagonal.element(n-2,0)
		+ f_above * f_above;

	T.element(1,1) =  diagonal.element(n-1,0)*diagonal.element(n-1,0)
		+ superdiagonal.element(n-2,0) * superdiagonal.element(n-2,0);

	T.element(0,1) =  superdiagonal.element(n-2,0) * diagonal.element(n-2,0);

	T.element(1,0) = T.element(0,1);

	// jacobi operates on a copy so that T stays intact for the comparison
	// below; presumably it leaves the eigenvalues on S's diagonal.
	ManifoldMatrix S (makeMatrix(T.M.copy()));

	jacobi(S);

	// Shift: the eigenvalue of the 2x2 block closest to T(1,1).
	if ( fabs(T.element(1,1) - S.element(0,0)) < fabs(T.element(1,1) - S.element(1,1)))
		mu = S.element(0,0);
	else
		mu = S.element(1,1);

	y -= mu;

	// Dense working copy of the block: diagonal and superdiagonal filled in,
	// everything else taken from the identity (i.e. zero off-diagonal).
	ManifoldMatrix M = RUMBA::identityMatrix(n);
	M.element(0,0) = diagonal.element(0,0);
	for ( int i = 1; i < n; ++i )
	{
		M.element(i,i) = diagonal.element(i,0);
		M.element(i-1,i) = superdiagonal.element(i-1,0);
	}

	// Chase the bulge created by the first (shifted) rotation down the block,
	// alternating column and row rotations.
	for ( int k = 0; k < n-1; ++k )
	{
		g = get_givens(y,z);
		g.setCoords(k,k+1);
		apply_right_givens(M,g);
		g.adjustCoords(offset,offset);
		transforms->push_back ( g );

		// The column rotation introduced an entry below the diagonal.
		y = M.element(k,k);
		z = M.element(k+1,k);

		g = get_givens(y,z);
		g.setCoords(k,k+1);
		apply_left_givens(M,g);
		g.adjustCoords(offset,offset);
		left_transforms->push_back ( g );

		// The row rotation introduced an entry two places above the diagonal.
		if (k<n-2)
		{
			assert (k<M.rows());
			assert (k+2<M.cols());
			y = M.element(k,k+1); 
			z = M.element(k,k+2);
		}
	}

	for ( int i = 0; i < n; ++i )
		diagonal.element(i,0) = M.element(i,i);

	for ( int i = 1; i < n; ++i )
		superdiagonal.element(i-1,0) = M.element(i-1,i);
}


// re-construct the right hand svd matrix V in the formula 
// M = UDV^t
//
// hh:     the matrix as left by houseHolderBiDiagonalize.  For row i, the
//         entries to the right of the superdiagonal are read as the i-th
//         right Householder vector, with its leading (superdiagonal)
//         component replaced by 1 (assumes that storage layout; see
//         householder.h).
// mylist: right Givens rotations collected during diagonalization, applied
//         after the Householder reflectors.
// Returns a cols x cols matrix.
ManifoldMatrix svd_accumulate_right
( 
 const ManifoldMatrix& hh, 
 const std::list<Givens>& mylist
 )
{
	// Keep the column count in a signed local: "hh.cols()-2" evaluated in an
	// unsigned return type would wrap around for matrices with fewer than two
	// columns instead of yielding an empty loop.
	const int n = hh.cols();
	ManifoldMatrix result = RUMBA::identityMatrix(n);

	// accumulate householder computation; each reflector is applied on the
	// right of the running product.
	for ( int i = 0; i < n - 2; ++i )
	{
		ManifoldMatrix v (hh.subMatrix(i,1,i+1,n-(i+1)).transpose());
		v.element(0,0) = 1;
		apply_hh_right (result,v);
	}

	multiply(result,mylist);

	return result;
}


// re-construct the left hand svd matrix U in the formula 
// M = UDV^t.  The rotations in mylist were applied to A as U^t, so they are
// replayed in reverse order (the trailing `true` flag of multiply presumably
// selects that).  After that, the left Householder reflectors are applied,
// last column first.  The reflectors are read from hh at and below the
// diagonal, with the diagonal component replaced by 1.
// The result has hh's shape (rows x cols) and starts as a "thin" identity.
ManifoldMatrix svd_accumulate_left
( 
 const ManifoldMatrix& hh, 
 const std::list<Givens>& mylist
 )
{
	ManifoldMatrix result = 
		RUMBA::makeMatrix(Manifold<double>(intPoint(hh.rows(),1,1,hh.cols())));

	// Ones on the main diagonal, zeros everywhere else.
	for ( int r = 0; r < result.rows(); ++r )
	{
		for ( int c = 0; c < result.cols(); ++c )
			result.element(r,c) = ( r == c ) ? 1.0 : 0.0;
	}

	multiply(mylist,result,true);

	// accumulate householder computation, from the last column back to 0
	int col = hh.cols();
	while ( col-- > 0 )
	{
		ManifoldMatrix v (hh.subMatrix(col, hh.rows()-col, col, 1));
		v.element(0,0) = 1;
		apply_hh_left(result,v);
	}

	return result;
}







// SVD driver shared by the public destructive_svd overloads.
//
// X is overwritten by houseHolderBiDiagonalize; its diagonal and
// superdiagonal are then copied into two column vectors, and the remaining
// entries are later read back as Householder vectors when U and V are
// accumulated.  The bidiagonal is driven to diagonal form by repeatedly:
//   1. zeroing negligible superdiagonal entries (svd_threshold);
//   2. locating the trailing unreduced block B22 (diagonal indices
//      p-1 .. n-q-1, all of its superdiagonal entries non-zero);
//   3. if a diagonal entry of B22 is (numerically) zero, splitting B22 with
//      Givens sweeps (flattenRow / flattenColumn); otherwise
//   4. running one implicit shifted QR sweep on B22 (svd_step).
// Every rotation is recorded and later combined with the Householder
// reflectors to form U (only when accumulate_U is true) and V.
// D receives the singular values on its diagonal.  Negative values are
// made positive by also negating the matching column of V, which leaves
// U D V^t unchanged.
void destructive_svd_impl
(
 ManifoldMatrix& X,
 ManifoldMatrix& U,
 ManifoldMatrix& V,
 ManifoldMatrix& D,
 bool accumulate_U
 )
{
	// Absolute threshold below which a diagonal entry is treated as zero.
	const double zero_tol = 1e-8;

	int p,q;
	const int n = X.cols();

	std::list<Givens> left_givens_list;
	std::list<Givens> right_givens_list;

	// get B;
	//
	houseHolderBiDiagonalize(X);

	ManifoldMatrix diagonal=
		makeMatrix(Manifold<double>(intPoint(n,1,1,1)));

	ManifoldMatrix superdiagonal=
		makeMatrix(Manifold<double>(intPoint(n-1,1,1,1)));

	diagonal.element(0,0) = X.element(0,0);

	for ( int i = 1; i < n; ++i )
	{
		diagonal.element(i,0) = X.element(i,i);
		superdiagonal.element(i-1,0) = X.element(i-1,i);
	}	

	// The loop is left through the break once the superdiagonal is all zero.
	do 
	{
		svd_threshold(diagonal, superdiagonal);

		// q = number of trailing superdiagonal entries that are exactly zero.
		q = 0;
		assert(n-2-q < superdiagonal.rows());
		while ( n-2-q >= 0 && superdiagonal.element(n-2-q,0) == 0 )
			++q;

		// Entire superdiagonal is zero: B is diagonal.
		if ( q== n-1 )
			break;
		
		// Grow the unreduced block upward until a zero superdiagonal entry
		// (or the top of the matrix) is reached.
		p = n - q;

		assert (p-2 < superdiagonal.rows());	
		while ( p>1 && superdiagonal.element(p-2,0) !=0  )
			--p;

		// Last diagonal index of the unreduced block B22.
		const int last = n-q-1;

		if ( q < n )
		{

			bool flag = false;
			for (int i = p-1; i<last; ++i )
			{
				if ( fabs(diagonal.element(i,0))< zero_tol && i < superdiagonal.rows())
				{
					// Zero diagonal: remove the superdiagonal entry in its row
					// and, when there is one, the entry above it in its column.
					diagonal.element(i,0) = 0;
					flattenRow(diagonal,superdiagonal,i,&left_givens_list);
					if ( i > 0 )
						flattenColumn
							( 
							 diagonal,
							 superdiagonal,
							 i,
							 &right_givens_list
							 );

					flag = true;

				}
			}

			// A zero in the last diagonal position of B22 has no superdiagonal
			// entry in its own row to chase away; instead the entry above it in
			// its column is removed with column rotations.  This is checked for
			// B22's own last index (not just the last index of the whole
			// matrix), so interior blocks are deflated too.
			if ( 
				fabs(diagonal.element(last,0)) < zero_tol && 
				fabs(superdiagonal.element(last-1,0)) != 0 )
			{
				diagonal.element(last,0) = 0;
				flattenColumn(diagonal,superdiagonal,last,&right_givens_list);
				flag = true;

			}

			if (!flag)
			{
				assert ( diagonal.rows() >= (p-1) + (n-p-q+1));
				ManifoldMatrix tmp_diag = diagonal.subMatrix(p-1,n-p-q+1,0,1);
				ManifoldMatrix tmp_superdiag = superdiagonal.subMatrix(p-1,n-p-q,0,1);
				svd_step(tmp_diag, tmp_superdiag, &right_givens_list, &left_givens_list,p-1);

				diagonal.put(p-1,0,tmp_diag);
				superdiagonal.put(p-1,0,tmp_superdiag);
				// B = diag ( I_p U I_{q+m-n} )
			}
		}
	} while ( q != n );

	if ( accumulate_U )
		U = svd_accumulate_left( X,left_givens_list );
	V = svd_accumulate_right( X,right_givens_list );

	// Singular values are non-negative by definition, but the rotations above
	// can leave negative diagonal entries.  Negate such an entry together with
	// column i of V; U D V^t is unchanged because the two sign flips cancel.
	for ( int i = 0; i < n; ++i )
	{
		if ( diagonal.element(i,0) < 0 )
		{
			diagonal.element(i,0) = -diagonal.element(i,0);
			for ( int r = 0; r < V.rows(); ++r )
				V.element(r,i) = -V.element(r,i);
		}
	}

	D = RUMBA::identityMatrix(V.rows());

	for ( int i = 0; i < D.rows(); ++i )
		D.element(i,i) *= diagonal.element(i,0);


}

/*

*/




}  /* namespace */
