// Copyright (C) 2006-2009 Kent-Andre Mardal and Simula Research Laboratory.
// Licensed under the GNU GPL Version 2, or (at your option) any later version.

#include "Nedelec.h"
#include <fstream>
#include "tools.h"

using std::cout;
using std::endl;
using std::string;
using GiNaC::exmap;

namespace SyFi
{

	// Default constructor: only the generic description is set here. No basis
	// functions are computed; compute_basis_functions() requires a polygon and
	// an order to have been set first.
	Nedelec::Nedelec()
		: StandardFE()
	{
		description = string("Nedelec");
	}

	// Builds a Nedelec element of the given order on polygon p and immediately
	// computes its basis functions (Ns) and degrees of freedom (dofs).
	// Propagates std::logic_error from compute_basis_functions() when order < 0.
	Nedelec::Nedelec(Polygon& p, int order)
		: StandardFE(p, order)
	{
		this->compute_basis_functions();
	}

	void Nedelec:: compute_basis_functions()
	{

		// remove previously computed basis functions and dofs
		Ns.clear();
		dofs.clear();

		if ( order < 0 )
		{
			throw(std::logic_error("Nedelec elements must be of order 0 or higher."));
		}

		if ( p == NULL )
		{
			throw(std::logic_error("You need to set a polygon before the basisfunctions can be computed"));
		}

		if ( p->str().find("Line") != string::npos )
		{

			cout <<"Can not define the Nedelec element on a line"<<endl;

		}
		else if ( p->str().find("Triangle") != string::npos )
		{

			description = istr("Nedelec_", order) + "2D";

			int k = order;
			int removed_dofs=0;

			Triangle& triangle = (Triangle&)(*p);
			GiNaC::lst equations;
			GiNaC::lst variables;

			// create r
			GiNaC::ex R_k = homogenous_polv(2,k+1, 2, "a");
			GiNaC::ex R_k_x = R_k.op(0).op(0);
			GiNaC::ex R_k_y = R_k.op(0).op(1);

			// Equations that make sure that r*x = 0
			GiNaC::ex rx = (R_k_x*x + R_k_y*y).expand();
			exmap pol_map = pol2basisandcoeff(rx);
			exmap::iterator iter;
			for (iter = pol_map.begin(); iter != pol_map.end(); iter++)
			{
				if ((*iter).second != 0 )
				{
					equations.append((*iter).second == 0 );
					removed_dofs++;
				}
			}

			// create p
			GiNaC::ex P_k = bernsteinv(2,k, triangle, "b");
			GiNaC::ex P_k_x = P_k.op(0).op(0);
			GiNaC::ex P_k_y = P_k.op(0).op(1);

			// collect the variables of r and p  in one list
			variables = collapse(GiNaC::lst(collapse(GiNaC::ex_to<GiNaC::lst>(R_k.op(1))),
				collapse(GiNaC::ex_to<GiNaC::lst>(P_k.op(1)))));

			// create the polynomial space
			GiNaC::lst pspace = GiNaC::lst( R_k_x + P_k_x,
				R_k_y + P_k_y);

			int counter = 0;
			GiNaC::symbol t("t");
			GiNaC::ex dofi;
			// dofs related to edges
			for (int i=0; i< 3; i++)
			{
				Line line = triangle.line(i);
				GiNaC::lst tangent_vec = tangent(triangle, i);
				GiNaC::ex bernstein_pol = bernstein(order, line, istr("a",i));
				GiNaC::ex basis_space = bernstein_pol.op(2);
				GiNaC::ex pspace_t = inner(pspace, tangent_vec);

				GiNaC::ex basis;
				for (unsigned int j=0; j< basis_space.nops(); j++)
				{
					counter++;
					basis = basis_space.op(j);
					GiNaC::ex integrand = pspace_t*basis;
					dofi =  line.integrate(integrand);
					GiNaC::ex eq = dofi == GiNaC::numeric(0);
					equations.append(eq);

					GiNaC::lst d = GiNaC::lst(GiNaC::lst(line.vertex(0), line.vertex(1)),j);
					dofs.insert(dofs.end(), d);

				}
			}

			// dofs related to the whole triangle
			GiNaC::lst bernstein_polv;
			if ( order > 0)
			{
				counter++;
				bernstein_polv = bernsteinv(2,order-1, triangle, "a");
				GiNaC::ex basis_space = bernstein_polv.op(2);
				for (unsigned int i=0; i< basis_space.nops(); i++)
				{
					GiNaC::lst basis = GiNaC::ex_to<GiNaC::lst>(basis_space.op(i));
					GiNaC::ex integrand = inner(pspace, basis);
					dofi = triangle.integrate(integrand);
					GiNaC::ex eq = dofi == GiNaC::numeric(0);
					equations.append(eq);

					GiNaC::lst d = GiNaC::lst(GiNaC::lst(triangle.vertex(0), triangle.vertex(1), triangle.vertex(2)), i);
					dofs.insert(dofs.end(), d);
				}
			}

			// invert the matrix:
			// GiNaC has a bit strange way to invert a matrix.
			// It solves the system AA^{-1} = Id.
			// It seems that this way is the only way to do
			// properly with the solve_algo::gauss flag.
			GiNaC::matrix b; GiNaC::matrix A;
			matrix_from_equations(equations, variables, A, b);

			unsigned int ncols = A.cols();
			GiNaC::matrix vars_sq(ncols, ncols);

			// matrix of symbols
			for (unsigned r=0; r<ncols; ++r)
				for (unsigned c=0; c<ncols; ++c)
					vars_sq(r, c) = GiNaC::symbol();

			GiNaC::matrix id(ncols, ncols);

			// identity
			const GiNaC::ex _ex1(1);
			for (unsigned i=0; i<ncols; ++i)
				id(i, i) = _ex1;

			// invert the matrix
			GiNaC::matrix m_inv(ncols, ncols);
			m_inv = A.solve(vars_sq, id, GiNaC::solve_algo::gauss);

			for (unsigned int i=0; i<dofs.size(); i++)
			{
				b.let_op(removed_dofs + i) = GiNaC::numeric(1);
				GiNaC::ex xx = m_inv.mul(GiNaC::ex_to<GiNaC::matrix>(b));

				GiNaC::lst subs;
				for (unsigned int ii=0; ii<xx.nops(); ii++)
				{
					subs.append(variables.op(ii) == xx.op(ii));
				}
				GiNaC::ex Nj1 = pspace.op(0).subs(subs);
				GiNaC::ex Nj2 = pspace.op(1).subs(subs);
				Ns.insert(Ns.end(), GiNaC::matrix(2,1,GiNaC::lst(Nj1,Nj2)));
				b.let_op(removed_dofs + i) = GiNaC::numeric(0);
			}

		}
		else if ( p->str().find("Tetrahedron") != string::npos )
		{

			description = istr("Nedelec_", order) + "3D";

			int k = order;
			int removed_dofs=0;

			Tetrahedron& tetrahedron= (Tetrahedron&)(*p);
			GiNaC::lst equations;
			GiNaC::lst variables;

			// create r
			GiNaC::ex R_k = homogenous_polv(3,k+1, 3, "a");
			GiNaC::ex R_k_x = R_k.op(0).op(0);
			GiNaC::ex R_k_y = R_k.op(0).op(1);
			GiNaC::ex R_k_z = R_k.op(0).op(2);

			// Equations that make sure that r*x = 0
			GiNaC::ex rx = (R_k_x*x + R_k_y*y + R_k_z*z).expand();
			exmap pol_map = pol2basisandcoeff(rx);
			exmap::iterator iter;
			for (iter = pol_map.begin(); iter != pol_map.end(); iter++)
			{
				if ((*iter).second != 0 )
				{
					equations.append((*iter).second == 0 );
					removed_dofs++;
				}
			}

			// create p
			GiNaC::ex P_k = bernsteinv(3,k, tetrahedron, "b");
			GiNaC::ex P_k_x = P_k.op(0).op(0);
			GiNaC::ex P_k_y = P_k.op(0).op(1);
			GiNaC::ex P_k_z = P_k.op(0).op(2);

			// collect the variables of r and p  in one list
			variables = collapse(GiNaC::lst(collapse(GiNaC::ex_to<GiNaC::lst>(R_k.op(1))),
				collapse(GiNaC::ex_to<GiNaC::lst>(P_k.op(1)))));

			// create the polynomial space
			GiNaC::lst pspace = GiNaC::lst( R_k_x + P_k_x,
				R_k_y + P_k_y,
				R_k_z + P_k_z);

			int counter = 0;
			GiNaC::symbol t("t");
			GiNaC::ex dofi;

			// dofs related to edges
			for (int i=0; i< 6; i++)
			{
				Line line = tetrahedron.line(i);
				GiNaC::ex line_repr = line.repr(t);
				GiNaC::lst tangent_vec = GiNaC::lst(line_repr.op(0).rhs().coeff(t,1),
					line_repr.op(1).rhs().coeff(t,1),
					line_repr.op(2).rhs().coeff(t,1));

				GiNaC::ex bernstein_pol = bernstein(order, line, istr("a",i));
				GiNaC::ex basis_space = bernstein_pol.op(2);
				GiNaC::ex pspace_t = inner(pspace, tangent_vec);

				GiNaC::ex basis;
				for (unsigned int j=0; j< basis_space.nops(); j++)
				{
					counter++;
					basis = basis_space.op(j);
					GiNaC::ex integrand = pspace_t*basis;
					dofi =  line.integrate(integrand);
					GiNaC::ex eq = dofi == GiNaC::numeric(0);
					equations.append(eq);

					GiNaC::lst d = GiNaC::lst(GiNaC::lst(line.vertex(0), line.vertex(1)),j);
					dofs.insert(dofs.end(), d);

				}
			}

			// dofs related to faces
			if ( order > 0 )
			{
				for (int i=0; i< 4; i++)
				{
					Triangle triangle = tetrahedron.triangle(i);
					GiNaC::ex bernstein_pol = bernsteinv(3,order-1, triangle, istr("b", i));
					GiNaC::ex basis_space = bernstein_pol.op(2);

					GiNaC::ex basis;
					GiNaC::lst normal_vec = normal(tetrahedron, i);
					GiNaC::ex pspace_n = cross(pspace, normal_vec);
					if ( normal_vec.op(0) != GiNaC::numeric(0) &&
						normal_vec.op(1) != GiNaC::numeric(0) &&
						normal_vec.op(2) != GiNaC::numeric(0) )
					{
						for (unsigned int j=0; j<basis_space.nops(); j++)
						{
							basis = basis_space.op(j);
							if ( basis.op(0) != 0 || basis.op(1) != 0 )
							{
								GiNaC::ex integrand = inner(pspace_n,basis);
								if ( integrand != GiNaC::numeric(0) )
								{
									dofi = triangle.integrate(integrand);
									GiNaC::ex eq = dofi == GiNaC::numeric(0);
									equations.append(eq);
								}
							}
						}

					}
					else
					{
						for (unsigned int j=0; j<basis_space.nops(); j++)
						{
							basis = basis_space.op(j);
							GiNaC::ex integrand = inner(pspace_n,basis);
							if ( integrand != GiNaC::numeric(0) )
							{
								dofi = triangle.integrate(integrand);
								GiNaC::ex eq = dofi == GiNaC::numeric(0);
								equations.append(eq);
							}
						}
					}
				}
			}

			// dofs related to tetrahedron
			if ( order > 1 )
			{
				GiNaC::ex bernstein_pol = bernsteinv(3,order-2, tetrahedron, istr("c", 0));
				GiNaC::ex basis_space = bernstein_pol.op(2);
				GiNaC::ex basis;
				for (unsigned int j=0; j<basis_space.nops(); j++)
				{
					basis = basis_space.op(j);
					GiNaC::ex integrand = inner(pspace,basis);
					dofi = tetrahedron.integrate(integrand);
					GiNaC::ex eq = dofi == GiNaC::numeric(0);
					equations.append(eq);

					GiNaC::lst d = GiNaC::lst(GiNaC::lst(tetrahedron.vertex(0), tetrahedron.vertex(1), tetrahedron.vertex(2), tetrahedron.vertex(3)), j);
					dofs.insert(dofs.end(), d);

				}
			}

			// invert the matrix:
			// GiNaC has a bit strange way to invert a matrix.
			// It solves the system AA^{-1} = Id.
			// It seems that this way is the only way to do
			// properly with the solve_algo::gauss flag.
			GiNaC::matrix b; GiNaC::matrix A;
			matrix_from_equations(equations, variables, A, b);

			unsigned int ncols = A.cols();
			GiNaC::matrix vars_sq(ncols, ncols);

			// matrix of symbols
			for (unsigned r=0; r<ncols; ++r)
				for (unsigned c=0; c<ncols; ++c)
					vars_sq(r, c) = GiNaC::symbol();

			GiNaC::matrix id(ncols, ncols);

			// identity
			const GiNaC::ex _ex1(1);
			for (unsigned i=0; i<ncols; ++i)
				id(i, i) = _ex1;

			// invert the matrix
			GiNaC::matrix m_inv(ncols, ncols);
			m_inv = A.solve(vars_sq, id, GiNaC::solve_algo::gauss);

			for (unsigned int i=0; i<dofs.size(); i++)
			{
				b.let_op(removed_dofs + i) = GiNaC::numeric(1);
				GiNaC::ex xx = m_inv.mul(GiNaC::ex_to<GiNaC::matrix>(b));

				GiNaC::lst subs;
				for (unsigned int ii=0; ii<xx.nops(); ii++)
				{
					subs.append(variables.op(ii) == xx.op(ii));
				}
				GiNaC::ex Nj1 = pspace.op(0).subs(subs);
				GiNaC::ex Nj2 = pspace.op(1).subs(subs);
				GiNaC::ex Nj3 = pspace.op(2).subs(subs);
				Ns.insert(Ns.end(), GiNaC::matrix(3,1,GiNaC::lst(Nj1,Nj2,Nj3)));
				b.let_op(removed_dofs + i) = GiNaC::numeric(0);
			}

		}

	}

}								 // namespace SyFi
