
/* hami : worked example for checking whether a function is a solution
   of the Hamiltonian eigenvalue equation H*psi = E*psi.

   You supply a potential and a function psi; the result is E(x),
   which should be a constant with respect to x if psi really is
   an eigenfunction.

   TODO:
   - Try it out on the hydrogen atom case.
   - (a+b+c)*d still cannot be reduced this way (a sum inside a
     product is not expanded into separate terms).
   - Complex(a,b) cannot be reduced.
   - Complex*vector does it the wrong way.
   */

Use("complex"); /* Use complex, because we want to use I effectively */


/* SecondDeriv : kinetic-energy term of a single particle applied to
   function, i.e. the sum over k = 1..dimension of
       -(d^2 function / d vector[k]^2) * h^2 / (2*mass).

   dimension : number of coordinates of the particle.
   vector    : list of coordinate symbols, e.g. {r1,r2,r3}.
   function  : expression to differentiate.
   mass      : mass of the particle.
   Returns the accumulated (unsimplified) expression, starting from 0.
   NOTE(review): h is a free global symbol here, presumably Planck's
   constant (hbar) -- confirm.
 */
SecondDeriv(dimension,vector,function,mass):=
[
  /* Keep the names result/i: Yacas locals are dynamically scoped, and
     function is evaluated inside Deriv below. */
  Local(result,i);
  result:=0;
  i:=1;
  While(i<=dimension)
  [
    result:=result-(Deriv(vector[i])Deriv(vector[i])function)*(h^2)/(2*mass);
    i++;
  ];
  result;
];

/* Vectors : bind each named variable to a symbolic vector of components.
   Example: after invoking Vectors(3,{r}), r should evaluate to {r1,r2,r3}.
   Several names can be given at once, e.g. Vectors(3,{r,s}).

   dimension : number of components per vector.
   vectors   : list of variable names; it is held unevaluated (see the
               HoldArg after the definition).
   Returns True.
 */


Vectors(dimension,vectors):=
[
  /* NOTE(review): the unusual loop-variable name presumably avoids
     clashing with a caller-supplied variable name -- confirm. */
  ForEach(vvecv,vectors)
  [
    /* MacroSet evaluates its first argument, so the variable named by
       vvecv (e.g. r) is assigned, not vvecv itself. */
    MacroSet(vvecv,MakeVector(vvecv,dimension));
  ];
  True;
];
/* Keep the list of names unevaluated, so names that already have a
   value are passed as names rather than as their values. */
HoldArg("Vectors",vectors);

/* Hamiltonian : apply the standard Hamiltonian
       H = sum_i p_i^2/(2*m_i) + V
   to function. The kinetic part is summed over all particles.

   dimension : number of coordinates per particle.
   vectors   : list of coordinate vectors, one per particle.
   masses    : list of masses; masses[i] belongs to vectors[i].
   potential : the potential V.
   function  : the function H acts on. Hamiltonian is declared Bodied
               below, so this argument may be written after the
               parenthesised argument list.
   Returns the sum of the kinetic terms plus potential*function.
 */
Hamiltonian(dimension,vectors,masses,potential,function):=
[
  /* Same local names as before: Yacas locals are dynamically scoped. */
  Local(result,i);
  result:=0;
  i:=1;
  While(i<=Length(vectors))
  [
    result:=result+SecondDeriv(dimension,vectors[i],function,masses[i]);
    i++;
  ];
  result+potential*function;
];
Bodied("Hamiltonian");

/* In : inner product of two vectors,
       v1[1]*v2[1] + v1[2]*v2[2] + ... + v1[dimension]*v2[dimension],
   accumulated from 0. Returns 0 when dimension is 0.
 */
In(v1,v2,dimension):=
[
  Local(i,result);
  result:=0;
  i:=1;
  While(i<=dimension)
  [
    result:=result+v1[i]*v2[i];
    i++;
  ];
  result;
];


/* This part of the module implements a set of functions that let a
   function f(x) be expressed as g(x)*h(x), where f(x) and h(x) are
   given. The hope is that g(x) does not actually depend on x.
   Obviously g(x) = f(x)/h(x) (as a limit where h(x) vanishes).

   The general form is expected to be term1+term2+term3+....
   So we need to factor h(x) out of each term, and then reassemble
   the result.

   Each term in turn consists of factor1*factor2*factor3*...

1. Substitute to internal format:
     x * y -> x InMul y
     x / y -> x InMul (1/y)
     x + y -> x InPlus y
     x - y -> x InPlus (-y)
   
2. Flatten over InPlus.

3. For each term in the flattened InPlus list, flatten over InMul.

4. Find the factors of interest and remove them from the list
   (if a factor is not present, multiply by its reciprocal instead).

5. Contract back to an expression.

 */


/* IsModify : True when the top-level operator of f is binary/unary
   +, -, * or /, i.e. a node DoModify knows how to convert; else False. */
IsModify(f):=Contains({"+","-","*","/"},Type(f));

/* IsUnModify : True when the top-level operator of f is one of the
   internal-format operators InPlus or InMul; else False. */
IsUnModify(f):=Contains({"InPlus","InMul"},Type(f));

/* InPlus and InMul are the internal stand-ins for binary + and *.
   No rules are defined for them here, so InPlus(x,y) and InMul(x,y)
   presumably stay unevaluated and can be flattened and inspected
   structurally. */
RuleBase("InPlus",{x,y});
RuleBase("InMul",{x,y});

/* DoModify : convert one +, -, * or / node into internal form, where
   x' denotes the converted operand:
     a + b  -> InPlus(a', b')
     a - b  -> InPlus(a', (-b)')
     -a     -> InMul(-1, a')
     a * b  -> InMul(a', b')
     a / b  -> InMul(a', 1/b)     (the denominator is NOT converted)
   Operands are converted by the one-argument Substitute, which
   presumably continues the enclosing Substitute(f,predicate,change)
   traversal -- NOTE(review): confirm.
 */
RuleBase("DoModify",{f});
/* NOTE(review): no NrArgs check here, unlike "-"; assumes "+" nodes are binary. */
Rule("DoModify",1,1,Type(f) = "+")
    InPlus(Substitute(f[1]),Substitute(f[2]));
/* Binary minus: a - b becomes a + (-b). */
Rule("DoModify",1,1,Type(f) = "-" And NrArgs(f) = 2)
    InPlus(Substitute(f[1]),Substitute(-(f[2])));
/* Unary minus: -a becomes (-1) * a. */
Rule("DoModify",1,1,Type(f) = "-" And NrArgs(f) = 1)
    InMul(-1,Substitute(f[1]));
Rule("DoModify",1,1,Type(f) = "*")
    InMul(Substitute(f[1]),Substitute(f[2]));
/* Division: only the numerator is converted; the denominator is kept
   whole as the single factor 1/(f[2]). */
Rule("DoModify",1,1,Type(f) = "/")
    InMul(Substitute(f[1]),1/(f[2]));
/* Let the rule bodies see the caller's local variables (presumably
   needed by the one-argument Substitute calls above). */
UnFence("DoModify",1);

/* DoUnModify : inverse of DoModify for one node:
     InPlus(a,b) -> a' + b'
     InMul(a,b)  -> a' * b'
   where the operands are converted back by the one-argument Substitute.
 */
RuleBase("DoUnModify",{f});
Rule("DoUnModify",1,1,Type(f) = "InPlus")
    Substitute(f[1]) + Substitute(f[2]);
Rule("DoUnModify",1,1,Type(f) = "InMul")
    Substitute(f[1]) * Substitute(f[2]);
/* Let the rule bodies see the caller's local variables (presumably
   needed by the one-argument Substitute calls above). */
UnFence("DoUnModify",1);

/* MakeModify : rewrite every +, -, *, / node of f into the internal
   InPlus/InMul form (predicate IsModify, rewriter DoModify). */
MakeModify(f):=
[
  Substitute(f,"IsModify","DoModify");
];
/* MakeUnModify : rewrite every InPlus/InMul node of f back into
   ordinary + and * (predicate IsUnModify, rewriter DoUnModify). */
MakeUnModify(f):=
[
  Substitute(f,"IsUnModify","DoUnModify");
];

/* Contract : inverse of Flatten. Folds list from the left with the
   binary operator op (a string such as "+" or "*"), starting from unit:
       Contract({a,b,c},"*",1)  evaluates  ((1*a)*b)*c
   An empty list yields unit unchanged.
   TODO this is a nice one to have next to Flatten...
 */
Contract(list,op,unit):=
[
  Local(result);
  result:=unit;
  ForEach(item,list)
  [
    /* Apply op directly to the accumulator and the next item. Nesting
       it as Apply("Apply",...) only re-evaluated the operands once more. */
    result:=Apply(op,{result,item});
  ];
  result;
];


/* ExtractFactor : divide a flattened product (a list of factors) by factor.
   If factor occurs literally (structurally) in list, the occurrence
   found by Find is removed; otherwise 1/factor is inserted at the front.
   Uses destructive list operations, so the list passed in may be
   modified; the resulting list is also returned.
 */
RuleBase("ExtractFactor",{list,factor});
/* Precedence 1: factor is present, so drop it from the product. */
Rule("ExtractFactor",2,1,Contains(list,factor))
[
  Local(index);
  index:=Find(list,factor);
  DestructiveDelete(list,index);
  list;
];
/* Precedence 10 (fallback): factor is not present, so multiply by its
   reciprocal instead. */
Rule("ExtractFactor",2,10,True)
[
  DestructiveInsert(list,1,(1/factor));
  list;
];

/* Extract(func,factor) : return g such that func = g*factor, dividing
   factor out of func term by term.
   Terms that do not contain factor literally as one of their product
   factors are multiplied by 1/factor instead (see ExtractFactor).
 */
RuleBase("Extract",{func,factor});

/* Precedence 3: a product factor a*b*... is divided out one factor at a time. */
Rule("Extract",2,3,Type(factor) = "*")
[
  Local(factors);
  factors:=Flatten(factor,"*");
  ForEach(item,factors)
  [
    func:=Extract(func,item);
  ];
  func;
];

/* Precedence 5 (general case): convert to internal form, split into
   terms, split each term into factors, remove factor from each term,
   then reassemble. */
Rule("Extract",2,5,True)
[
  /* Get the internal representation (InPlus/InMul) */
  Local(intern,f2);

  intern:=MakeModify(func);
  /* The factor is converted as well, so it compares structurally
     equal to the converted factors of each term. */
  f2:=MakeModify(factor);
  /* Flatten for InPlus: one list entry per additive term. */
  Local(terms,termsout);
  terms:=Flatten(intern,"InPlus");
  termsout:={};
  ForEach(item,terms)
  [
    /* Flatten for InMul: one list entry per multiplicative factor. */
    Local(flat);
    flat:=Flatten(item,"InMul");
    Local(toadd);
    /* Remove f2 (or multiply by 1/f2) and multiply the rest back together. */
    toadd:=Contract(ExtractFactor(flat,f2),"*",1);
    DestructiveAppend(termsout,toadd);
  ];
  
  /* Return back to the outside world: replace back the + and * */
  Local(extern);
  extern:=MakeUnModify(termsout);

  /* Sum the processed terms, starting from 0. */
  Contract(extern,"+",0);
];

