import math
import numpy as num
from numpy.linalg import solve as solveLinEq
# from LinearAlgebra import solve_linear_equations as solveLinEq


# ----
class DerivVar:
    """Number carrying its first partial derivatives (forward-mode automatic
    differentiation for functions of any number of variables).

    Attributes:
        value -- the numeric value.
        deriv -- list of first partial derivatives; entry i is the derivative
                 with respect to independent variable i. Missing trailing
                 entries are treated as zero.

    Arithmetic (+, -, *, /, **, unary -) and comparisons accept another
    DerivVar or a plain number; a plain number is treated as a constant
    (empty derivative list). Comparisons look at ``value`` only.
    """

    def __init__(self, value, index=0, order=1):
        """Create a DerivVar.

        :param value: the numeric value.
        :param index: index of the independent variable this object represents
            (its derivative list becomes ``index*[0] + [1]``), or an explicit
            list of derivatives.
        :param order: 1 (default) for a variable, 0 for a constant with an
            empty derivative list.
        :raises ValueError: if ``order > 1``.
        """
        if order > 1:
            raise ValueError('Only first-order derivatives')
        self.value = value
        if order == 0:
            self.deriv = []
        elif isinstance(index, list):
            self.deriv = index
        else:
            self.deriv = index*[0] + [1]

    def __getitem__(self, item):
        """Return ``value`` for item 0 and ``deriv`` for item 1.

        :raises ValueError: for any other index.
        """
        if item < 0 or item > 1:
            raise ValueError('Index out of range')
        if item == 0:
            return self.value
        return self.deriv

    def _coerced(self, other):
        """Return *other* as a DerivVar, wrapping plain numbers as constants."""
        if isDerivVar(other):
            return other
        return DerivVar(other, [])

    def __coerce__(self, other):
        # Python 2 classic-class hook. Every operator below also coerces
        # explicitly, so they work on Python 3, which never calls this.
        return self, self._coerced(other)

    # --- comparisons (value only; derivatives are ignored) ---

    def __cmp__(self, other):
        # Three-way comparison for Python 2; Python 3 uses the rich methods.
        a, b = self.value, self._coerced(other).value
        return (a > b) - (a < b)

    def __eq__(self, other):
        return self.value == self._coerced(other).value

    def __ne__(self, other):
        return self.value != self._coerced(other).value

    def __lt__(self, other):
        return self.value < self._coerced(other).value

    def __le__(self, other):
        return self.value <= self._coerced(other).value

    def __gt__(self, other):
        return self.value > self._coerced(other).value

    def __ge__(self, other):
        return self.value >= self._coerced(other).value

    # Value-based equality => unhashable, as __cmp__ without __hash__ was on Python 2.
    __hash__ = None

    # --- arithmetic ---

    def __add__(self, other):
        other = self._coerced(other)
        return DerivVar(self.value + other.value,
                        _mapderiv(lambda a, b: a + b, self.deriv, other.deriv))
    __radd__ = __add__

    def __sub__(self, other):
        other = self._coerced(other)
        return DerivVar(self.value - other.value,
                        _mapderiv(lambda a, b: a - b, self.deriv, other.deriv))

    def __rsub__(self, other):
        # other - self, e.g. ``1 - x``.
        return self._coerced(other) - self

    def __neg__(self):
        return DerivVar(-self.value, [-d for d in self.deriv])

    def __mul__(self, other):
        other = self._coerced(other)
        # Product rule: d(uv) = v*du + u*dv
        return DerivVar(self.value*other.value,
                        _mapderiv(lambda a, b: a + b,
                                  [other.value*d for d in self.deriv],
                                  [self.value*d for d in other.deriv]))
    __rmul__ = __mul__

    def __div__(self, other):
        other = self._coerced(other)
        if not other.value:
            raise ZeroDivisionError('DerivVar division')
        inv = 1./other.value
        scale = self.value*inv*inv
        # Quotient rule: d(u/v) = du/v - u*dv/v**2
        return DerivVar(self.value*inv,
                        _mapderiv(lambda a, b: a - b,
                                  [inv*d for d in self.deriv],
                                  [scale*d for d in other.deriv]))
    __truediv__ = __div__  # Python 3 / ``from __future__ import division``

    def __rdiv__(self, other):
        return self._coerced(other)/self
    __rtruediv__ = __rdiv__

    def __pow__(self, other, z=None):
        if z is not None:
            raise TypeError('DerivVar does not support ternary pow()')
        other = self._coerced(other)
        val1 = pow(self.value, other.value - 1)
        val = val1*self.value
        # d(u**w) = w*u**(w-1)*du  (+ u**w*log(u)*dw when w varies)
        scale1 = val1*other.value
        deriv1 = [scale1*d for d in self.deriv]
        if len(other.deriv) > 0:
            scale2 = val*num.log(self.value)
            deriv2 = [scale2*d for d in other.deriv]
            return DerivVar(val, _mapderiv(lambda a, b: a + b, deriv1, deriv2))
        return DerivVar(val, deriv1)

    def __rpow__(self, other):
        return pow(self._coerced(other), self)
# ----


# ----
def isDerivVar(x):
    """Return True if *x* has both a ``value`` and a ``deriv`` attribute.

    This is a duck-typed test: any object exposing those two attributes is
    accepted, not only DerivVar instances. Returns a bool.
    """
    for attribute in ('value', 'deriv'):
        if not hasattr(x, attribute):
            return False
    return True
# ----


# ----
def leastSquaresFit(model, parameters, data, max_iterations=None, stopping_limit = 0.005):
    """General non-linear least-squares fit using the
    Levenberg-Marquardt algorithm and automatic derivatives.

    :param model: callable invoked as ``model(params, x)``, where ``params``
        is a list of DerivVar objects (one per fit parameter). It must return
        a DerivVar so chi-square derivatives can be formed.
    :param parameters: sequence of initial parameter values.
    :param data: sequence of ``(x, y)`` or ``(x, y, sigma)`` points
        (sigma defaults to 1).
    :param max_iterations: maximum number of iterations, or None for no
        limit. When the limit is reached, the best parameters accepted so far
        are returned instead of iterating forever.
    :param stopping_limit: the fit stops once an accepted step lowers
        chi-square by less than this amount.
    :returns: ``(list of fitted parameter values, chi-square value)``.
    """
    n_param = len(parameters)
    # Parameter i becomes independent variable i (derivative 1 in slot i).
    p = [DerivVar(value, i) for i, value in enumerate(parameters)]
    identity = num.identity(n_param)
    lam = 0.001  # Levenberg-Marquardt damping factor
    chi_sq, alpha = _chiSquare(model, p, data)
    niter = 0
    while True:
        # num.diagonal(alpha)*identity broadcasts to diag(alpha), so the
        # damping term scales alpha's diagonal by (1 + lam).
        delta = solveLinEq(alpha + lam*num.diagonal(alpha)*identity,
                           -0.5*num.array(chi_sq[1]))
        # Build lists explicitly: _chiSquare needs len(), and map() is lazy
        # on Python 3.
        next_p = [a + b for a, b in zip(p, delta)]
        next_chi_sq, next_alpha = _chiSquare(model, next_p, data)
        if next_chi_sq[0] > chi_sq[0]:
            # Step made things worse: reject it and damp harder.
            lam = 10.*lam
        else:
            lam = 0.1*lam
            if chi_sq[0] - next_chi_sq[0] < stopping_limit:
                return [q[0] for q in next_p], next_chi_sq[0]
            p = next_p
            chi_sq = next_chi_sq
            alpha = next_alpha
        niter = niter + 1
        if max_iterations is not None and niter >= max_iterations:
            # Iteration budget exhausted: return the best accepted point
            # (the last trial step may have been rejected).
            return [q[0] for q in p], chi_sq[0]
# ----


# ----
def _chiSquare(model, parameters, data):
    """ Count Chi-square. """

    n_param = len(parameters)
    chi_sq = 0.
    alpha = num.zeros((n_param, n_param))
    for point in data:
        sigma = 1
        if len(point) == 3:
            sigma = point[2]
        f = model(parameters, point[0])
        chi_sq = chi_sq + ((f-point[1])/sigma)**2
        d = num.array(f[1])/sigma
        alpha = alpha + d[:,num.newaxis]*d
    return chi_sq, alpha
# ----


# ----
def _mapderiv(func, a, b):
    """ Map a binary function on two first derivative lists. """

    nvars = max(len(a), len(b))
    a = a + (nvars-len(a))*[0]
    b = b + (nvars-len(b))*[0]
    return map(func, a, b)
# ----
