from sage.all import matrix, diagonal_matrix, vector, QQ
from sage.all import prod, RDF, RealField, CDF, CartesianProduct, cputime
from sage.all import randint, ideal, singular, RealField
from sage.all import save, ComplexField, Rational, NumberFieldElement
import ws_ev

def left_eigenvectors(m, prec=80):
    """
    Returns the left eigenvalues and eigenvectors of m, using a method
    suggested in William Stein's PhD thesis.

    For an eigenvalue ``e`` lying in a number field K = QQ[x]/(g), one
    eigenvector with entries in K is computed exactly. Its conjugate
    eigenpairs are then obtained numerically: every complex root r of g is
    substituted into the polynomial representations of ``e`` and of each
    eigenvector entry.

    INPUT:
        m -- matrix with rational coefficients
        prec -- bits of precision of the ComplexField used for the roots.

    OUTPUT:
        [eigenvalues, eigenvectors] -- two lists of equal length;
                                       eigenvectors[i] (a list of entries)
                                       belongs to eigenvalues[i].

    NOTE(review): only eigenvalues of type ``Rational`` or
    ``NumberFieldElement`` are handled; any other eigenvalue type is skipped
    silently. Newer Sage versions may return QQbar eigenvalues from
    ``eigenspaces()`` by default -- verify against the Sage version in use.
    Only the first basis vector of each eigenspace is used, so a
    multi-dimensional eigenspace contributes a single eigenvector.

    EXAMPLES:
        sage: q = matrix(QQ, 3, [1, 2, -1/2, -1, -1, 0, -1/2, 2, -2])
        sage: stickel3.left_eigenvectors(q)

        [[-1.8144605536121012356138,
          -0.092769723193949382193097 - 0.63619161566987696439867*I,
          -0.092769723193949382193097 + 0.63619161566987696439867*I],
         [[1.0000000000000000000000,
           4.1618829204835455363812,
           -2.6948447337428886015349],
          [1.0000000000000000000000,
           1.2107252064248938984760 + 0.67553782905921424189663*I,
           -0.23591096646188903256590 - 0.078692426778674554995928*I],
          [1.0000000000000000000000,
           1.2107252064248938984760 - 0.67553782905921424189663*I,
           -0.23591096646188903256590 + 0.078692426778674554995928*I]]]

    """
    eigenvalues = []
    eigenvectors = []
    for e, v in m.eigenspaces():
        if isinstance(e, Rational):
            # Record the eigenvalue too, so both output lists stay aligned.
            eigenvalues.append(e)
            eigenvectors.append(list(v.gen()))
        elif isinstance(e, NumberFieldElement):
            basis_vector = v.basis()[0]
            dpoly = v.base_field().defining_polynomial()
            # Express e and every entry as polynomials in the field
            # generator; substituting a complex root of the defining
            # polynomial yields one conjugate eigenpair.
            e_poly = e.polynomial()
            entry_polys = [p.polynomial() for p in basis_vector]
            for r in [t[0] for t in dpoly.roots(ComplexField(prec))]:
                eigenvalues.append(e_poly(r))
                eigenvectors.append([q(r) for q in entry_polys])

    return [eigenvalues, eigenvectors]
    
def fast_stickel(I, verbose=False):
    """
    Computes the complex variety of a zero dimensional ideal using
    Stickelberger's algorithm.

    A random linear form f is chosen, and the left eigenvectors of its
    multiplication matrix on the normal basis are computed. Each eigenvector
    yields one point:
      - coordinates of variables that appear in the normal basis are read
        directly from the eigenvector, after scaling by the entry of the
        coset [1];
      - each remaining variable is recovered from the Groebner basis element
        whose leading monomial is that variable.

    NOTE(review): I is not replaced by its radical; the eigenvector approach
    presumably needs one-dimensional eigenspaces -- confirm for the inputs
    used.

    INPUT:
        I -- an ideal
        verbose -- if True, print intermediate data
    OUTPUT:
        points -- the variety of I as a list of points

    RAISES:
        ValueError -- if a coordinate cannot be recovered, because no
                      Groebner basis element has the corresponding variable
                      as its leading monomial.

    EXAMPLE:
        sage: R.<x,y,z> = PolynomialRing(QQ)
        sage: I = R.ideal(x^2-2*x*z+5, x*y^2+y*z+1, 3*y^2-8*x*z)
        sage: v = stickel3.fast_stickel(I)

    """
    base_ring = I.ring()
    base_ring_gens = list(base_ring.gens())
    gb = I.groebner_basis(algorithm='libsingular:std')
    f = random_f(I)
    nb = I.normal_basis()
    m_f = transform_matrix(gb, f, nb)
    evals, evectors = ws_ev.left_eigenvectors(m_f)
    if verbose:
        print("Groebner basis for %s: %s" % (I, gb))
        print("Using random f: %s" % f)
        print("The normal set is: %s" % nb)
        print("The transform matrix is: %s" % m_f)
        print("The left eigenvalues, eigenvectors are:  %s %s"
              % (evals, evectors))

    points = []
    for evec in evectors:
        point = [0] * len(base_ring_gens)
        unresolved = set(base_ring_gens)
        # Assumes the coset [1] is always the last entry of the normal basis.
        c = evec[-1]
        for idx, b in enumerate(nb):
            if b.is_univariate() and b.total_degree() == 1:
                point[base_ring_gens.index(b)] = evec[idx] / c
                unresolved.discard(b)

        while unresolved:
            cheapest = min(unresolved)
            for g in gb:
                if g.lm() == cheapest:
                    # g = lc * cheapest + tail, so g(point) = 0 gives
                    # cheapest = -tail(point) / lc.
                    tail = g - g.lt()
                    point[base_ring_gens.index(cheapest)] = \
                        -tail(point) / g.lc()
                    break
            else:
                # Without this check the original loop never terminated.
                raise ValueError("cannot recover coordinate %s: no Groebner "
                                 "basis element has it as leading monomial"
                                 % cheapest)
            unresolved.discard(cheapest)
        points.append(point)

    return points
        
def stickel(I, tolerance=0.000001):
    """
    Computes the variety of a zero dimensional ideal.

    For every ring generator x_i, the multiplication matrix of x_i with
    respect to the normal basis is built. Its eigenvalues, computed
    numerically over RDF, are the candidate values of the i-th coordinate.
    Every combination of candidates is then tested against the generators
    of I.

    INPUT: A zero dimensional ideal I.
           tolerance -- a generator g counts as vanishing at a candidate
                        point p when abs(g(*p)) < tolerance. A tolerance is
                        needed even for exact base fields, because the
                        candidate coordinates are floating point.
    OUTPUT: V(I) which is the set of points at which all polynomials f in I
    vanish, as a list without duplicates.

    """
    gb = I.groebner_basis(algorithm='singular:std')
    ns = normal_basis(I, gb)
    gens = I.gens()

    coords = []
    for v in I.ring().gens():
        # Reduce modulo the Groebner basis (not the raw generators of I),
        # so that normal forms are unique.
        m = transform_matrix(gb, v, ns).change_ring(RDF)
        coords.append(m.eigenvalues())

    variety = []
    for p in CartesianProduct(*coords):
        if all(abs(g(*p)) < tolerance for g in gens):
            if p not in variety:
                variety.append(p)

    return variety
    
def normal_basis(I, gb):
    """
    Computes the normal set of an ideal I.

    The monomials come from Singular's ``kbase`` applied to ``gb``. They are
    converted back to polynomials in ``I.ring()``.

    INPUT: Ideal I (only its ring is used, to convert the results back)
           Groebner basis gb of I
    OUTPUT: The normal set of the groebner basis, as a list of monomials.

    This only works if the dimension of the ideal is zero!
    """
    ring = I.ring()
    return [t.sage_poly(ring) for t in singular.kbase(ideal(gb))]

def transform_matrix(gb, f, normal_basis):
    """
    Computes the multiplication (transform) matrix of f with respect to a
    normal basis, modulo a Groebner basis.

    Column i holds the coefficients, indexed by ``normal_basis``, of the
    reduction of ``f * normal_basis[i]`` modulo ``gb``.

    INPUT:
        gb -- Groebner basis; the reductions are computed modulo it
        f -- a polynomial in the same ring as gb
        normal_basis -- list of monomials forming the normal basis,
                        i.e. a basis of R[x]/I

    OUTPUT:
        a square coefficient matrix m_f of size len(normal_basis), over the
        base ring of gb[0]

    EXAMPLE:
        sage: R.<x,y,z> = PolynomialRing(QQ)
        sage: I = R.ideal(x^2-2*x*z+5, x*y^2+y*z+1, 3*y^2-8*x*z)
        sage: gb = I.groebner_basis()
        sage: nb = I.normal_basis()
        sage: stickel3.transform_matrix(gb, x, nb)

        [    0    -1     0     0    -2     0     0     0]
        [-3/16 -3/20  -3/8     0 -3/10     0     0     0]
        [    0     0     0     1     0     0     2     0]
        [  5/2     0     0     0     0     0     0     0]
        [    0  3/40     0     0  3/20     1     0     0]
        [    0     0     0     0    -5     0     0     0]
        [    0     0     0     0     0     0     0     1]
        [-3/16     0  -3/8     0     0     0    -5     0]

    """
    tmatrix = matrix(gb[0].base_ring(), len(normal_basis))
    for i, mon in enumerate(normal_basis):
        remainder = (f * mon).reduce(gb)
        # monomial_coefficient returns 0 for monomials absent from
        # remainder, so no membership test is needed.
        column = [remainder.monomial_coefficient(m) for m in normal_basis]
        tmatrix.set_column(i, column)

    return tmatrix

def random_f(I, bound=10**3):
    """
    Returns a random linear form sum(c_i * x_i) over the generators x_i of
    I.ring().

    INPUT:
        I -- an ideal; only the generators of its ring are used
        bound -- each coefficient c_i is a random integer in [1, bound]
                 (default 1000, matching the previous hard-coded range);
                 must be >= 1
    OUTPUT:
        the random linear form
    """
    return sum(randint(1, bound) * x for x in I.ring().gens())
