#from sage.all import *

import pdb

def compare_by_degree(f, g):
    """
    Three-way comparison of two polynomials: by total degree, then by
    leading monomial.

    Returns ``1`` if ``f`` orders after ``g``, ``-1`` if before and ``0`` if
    they tie; usable with ``functools.cmp_to_key``.
    """
    df = f.total_degree()
    dg = g.total_degree()
    if df > dg:
        return 1
    elif df < dg:
        return -1
    mf, mg = f.lm(), g.lm()
    # Portable replacement for the Python 2-only ``cmp`` builtin.
    return (mf > mg) - (mf < mg)

class F5(SageObject):
    """
    F5-style signature-based Groebner basis computation with sugar degrees.

    Internal state (rebuilt by ``__init__`` on every ``basis`` call):

    * ``self.L`` -- list of labelled polynomials; entry ``k`` is the tuple
      ``(sugar_degree, Signature, polynomial)`` (see ``sugar``/``sig``/``poly``).
    * ``self.Rules`` -- ``self.Rules[i]`` is the list of rewriting rules
      ``(monomial, k)`` recorded for signature index ``i`` (see ``add_rule``).
    """

    def __init__(self, F=None):
        """
        Reset the internal tables for the generator sequence ``F``.

        With ``F`` omitted nothing is initialised; ``basis`` calls this again
        with the actual generators.
        """
        if F is not None:
            # One (initially empty) rule list per signature index.
            self.Rules = [list() for _ in F]
            # Slot 0 is overwritten by ``basis`` with the first generator.
            self.L = [None]

    def poly(self, i):
        """Return the polynomial of labelled polynomial ``i``."""
        return self.L[i][2]

    def sig(self, i):
        """Return the ``Signature`` of labelled polynomial ``i``."""
        return self.L[i][1]

    def sugar(self, i):
        """Return the sugar degree of labelled polynomial ``i``."""
        return self.L[i][0]

    def basis(self, F):
        """
        Return a Groebner basis, as a ``set`` of polynomials, of the ideal
        generated by ``F``.

        Generators are processed in the order given by ``compare_by_degree``.
        ``set([1])`` is returned as soon as the unit ideal is detected. Zero
        generators are ignored, and an empty (or all-zero) ``F`` yields the
        empty set.
        """
        from functools import cmp_to_key

        poly = self.poly
        incremental_basis = self.incremental_basis

        # Zero generators contribute nothing and cannot be made monic.
        F = [f for f in F if f != 0]
        if not F:
            return set()
        F = sorted(F, key=cmp_to_key(compare_by_degree))

        self.__init__(F)
        Rules = self.Rules
        L = self.L

        f0 = F[0]
        L[0] = (f0.degree(), Signature(1, 0), f0 * f0.lc()**(-1))
        Rules[0] = list()

        Gprev = set([0])
        B = [0]
        for i in range(1, len(F)):
            Gcurr = incremental_basis(F[i], i, B, Gprev)
            if any(poly(lambd) == 1 for lambd in Gcurr):
                return set([1])
            Gprev = Gcurr
            B = list(Gprev)
        return set(poly(each) for each in B)

    def incremental_basis(self, f, i, B, Gprev):
        """
        Extend the basis ``Gprev`` by the generator ``f`` of index ``i``.

        INPUT:

        - ``f`` -- nonzero generator, stored monic with signature ``(1, i)``.
        - ``i`` -- signature index of ``f``.
        - ``B`` -- indices into ``self.L`` used as reducers by ``sugar_reduce``.
        - ``Gprev`` -- set of indices into ``self.L`` of the previous basis.

        OUTPUT: the set of indices into ``self.L`` of the enlarged basis.
        """
        L = self.L
        critical_pair = self.critical_pair
        compute_spols = self.compute_spols
        reduction = self.reduction
        Rules = self.Rules

        L.append((f.degree(), Signature(1, i), f * f.lc()**(-1)))
        curr_idx = len(L) - 1
        Gcurr = Gprev.union([curr_idx])
        Rules[i] = list()

        P = set()
        for j in Gprev:
            P.update(critical_pair(curr_idx, j, i, Gprev))
        while P:
            # Process all pairs of minimal sugar degree together.
            D = min(pair[0] for pair in P)
            verbose("D %s" % (D,), level=1)
            Pd = [pair for pair in P if pair[0] == D]
            P.difference_update(Pd)
            S = compute_spols(Pd)
            R = reduction(S, B, Gprev, Gcurr)
            for k in R:
                for j in Gcurr:
                    P.update(critical_pair(j, k, i, Gprev))
                Gcurr.add(k)
        return Gcurr

    def critical_pair(self, k, l, i, Gprev):
        """
        Return the critical pair of ``L[k]`` and ``L[l]`` as a set holding at
        most one tuple ``(d, t, k, u0, l, u1)``.

        ``t`` is the lcm of the two leading terms, ``u0 = t // lt(k)`` and
        ``u1 = t // lt(l)`` are the multipliers, and ``d`` is the sugar degree.
        The pair is oriented so that ``u0 * sig(k)`` is not smaller than
        ``u1 * sig(l)``. The empty set is returned when the pair is discarded:
        both multiplied signatures are equal, a multiplied signature of index
        ``i`` is top-reducible by ``Gprev``, or either side is rewritable.
        """
        poly = self.poly
        sig = self.sig
        sugar = self.sugar
        is_top_reducible = self.is_top_reducible
        is_rewritable = self.is_rewritable

        tk = poly(k).lt()
        tl = poly(l).lt()
        t = lcm(tk, tl)
        u0 = t // tk
        u1 = t // tl
        m0, e0 = sig(k)
        m1, e1 = sig(l)
        if e0 == e1 and u0 * m0 == u1 * m1:
            return set()
        # Signatures of the current index that are top-reducible by the
        # previous basis are discarded.
        if e0 == i and is_top_reducible(u0 * m0, Gprev):
            return set()
        if e1 == i and is_top_reducible(u1 * m1, Gprev):
            return set()
        if is_rewritable(u0, k) or is_rewritable(u1, l):
            return set()
        if u0 * sig(k) < u1 * sig(l):
            u0, u1 = u1, u0
            k, l = l, k
        d = max(sugar(k) + u0.degree(), sugar(l) + u1.degree())
        return set([(d, t, k, u0, l, u1)])

    def compute_spols(self, P):
        """
        Turn the critical pairs ``P`` into new labelled S-polynomials.

        Pairs are visited in increasing ``t`` order; pairs with a rewritable
        side and zero S-polynomials are skipped. Each new entry gets signature
        ``u * sig(k)`` and a matching rewriting rule. Returns the indices of
        the new entries of ``self.L``, sorted by signature.
        """
        poly = self.poly
        sig = self.sig
        spol = self.spol
        is_rewritable = self.is_rewritable
        add_rule = self.add_rule
        L = self.L

        S = list()
        for (d, t, k, u, l, v) in sorted(P, key=lambda pair: pair[1]):
            # Re-checked per pair: rules added in this loop can apply.
            if is_rewritable(u, k) or is_rewritable(v, l):
                continue
            s = spol(poly(k), poly(l))
            if s != 0:
                L.append((d, u * sig(k), s))
                add_rule(u * sig(k), len(L) - 1)
                S.append(len(L) - 1)
        S.sort(key=sig)
        return S

    def spol(self, f, g):
        """Return the S-polynomial ``lcm/lt(f) * f - lcm/lt(g) * g``."""
        lcm_fg = f.parent().monomial_lcm(f.lm(), g.lm())
        return lcm_fg // f.lt() * f - lcm_fg // g.lt() * g

    def reduction(self, S, B, Gprev, Gcurr):
        """
        Reduce the labelled polynomials ``S``, smallest signature first.

        Each entry is first fully reduced by ``B`` (``sugar_reduce``), then
        top-reduced against ``Gcurr`` and the entries completed so far.
        Entries handed back by ``top_reduction`` are re-queued in signature
        order. Returns the set of indices that became new basis elements.
        """
        sig = self.sig
        L = self.L
        top_reduction = self.top_reduction
        sugar_reduce = self.sugar_reduce

        to_do = list(S)
        completed = set()
        while to_do:
            k = to_do.pop(0)
            d, h = sugar_reduce(k, B)
            L[k] = (d, sig(k), h)
            newly_completed, redo = top_reduction(k, Gprev, Gcurr.union(completed))
            completed.update(newly_completed)
            if redo:
                to_do.extend(redo)
                to_do.sort(key=sig)
        return completed

    def sugar_reduce(self, k, B):
        """
        Fully reduce ``poly(k)`` by the polynomials indexed by ``B``.

        Returns ``(d, r)``. ``r`` is the remainder: its terms are not
        divisible by any leading monomial of ``B``. ``d`` is ``sugar(k)``,
        raised to ``sugar(b) + deg(u)`` for every reducer ``u * poly(b)``
        that is used.
        """
        sugar = self.sugar
        poly = self.poly
        d = sugar(k)
        p = poly(k)
        r = 0
        reducers = [(b, poly(b), poly(b).lm()) for b in B]
        while p != 0:
            t = p.lm()
            a = p.lc()
            for b, g, m in reducers:
                if m.divides(t):
                    u = t // m
                    # The reducer u * g carries sugar sugar(b) + deg(u).
                    d = max(d, sugar(b) + u.degree())
                    p = p - (a / g.lc()) * u * g
                    break
            else:
                # No reducer applies: move the leading term to the remainder.
                r = r + a * t
                p = p - a * t
        return d, r

    def top_reduction(self, k, Gprev, Gcurr):
        """
        Perform one signature-safe top-reduction step on ``L[k]``.

        Returns ``(completed, redo)``:

        - zero polynomial: ``(set(), set())``;
        - no admissible reductor: ``L[k]`` is made monic and
          ``({k}, set())`` is returned;
        - reductor with smaller multiplied signature: ``L[k]`` is replaced
          in place and ``(set(), {k})`` is returned;
        - otherwise the reduced polynomial is appended as a new entry with
          the reductor's multiplied signature, and ``(set(), {k, new})`` is
          returned.
        """
        find_reductor = self.find_reductor
        add_rule = self.add_rule
        poly = self.poly
        sig = self.sig
        sugar = self.sugar
        L = self.L

        p = poly(k)
        if p == 0:
            verbose("reduction to zero.", level=0)
            return set(), set()
        J = find_reductor(k, Gprev, Gcurr)
        if not J:
            L[k] = (sugar(k), sig(k), p * p.lc()**(-1))
            return set([k]), set()
        j = J.pop()
        q = poly(j)
        # Pure monomial multiplier: p is generally not monic here, and a
        # coefficient must not leak into signatures or rewriting rules.
        u = p.lm() // q.lm()
        p = p - (p.lc() / q.lc()) * u * q
        if p != 0:
            p = p * p.lc()**(-1)
        d = max(sugar(k), sugar(j) + u.degree())
        if u * sig(j) < sig(k):
            L[k] = (d, sig(k), p)
            return set(), set([k])
        L.append((d, u * sig(j), p))
        add_rule(u * sig(j), len(L) - 1)
        return set(), set([k, len(L) - 1])

    def find_reductor(self, k, Gprev, Gcurr):
        """
        Return ``{j}`` for the first admissible top-reductor ``j`` of
        ``L[k]`` in ``Gcurr``, or ``set()`` if there is none.

        ``j`` is admissible if its leading monomial divides that of
        ``poly(k)`` and, with ``u`` the monomial quotient:
        ``u * sig(j) != sig(k)``, ``(u, j)`` is not rewritable, and
        ``u`` times the monomial of ``sig(j)`` is not top-reducible by
        ``Gprev``.
        """
        is_rewritable = self.is_rewritable
        is_top_reducible = self.is_top_reducible
        poly = self.poly
        sig = self.sig

        t = poly(k).lm()
        for j in Gcurr:
            tprime = poly(j).lm()
            if not tprime.divides(t):
                continue
            u = t // tprime
            mj = sig(j)[0]
            if (u * sig(j) != sig(k) and not is_rewritable(u, j)
                    and not is_top_reducible(u * mj, Gprev)):
                return set([j])
        return set()

    def add_rule(self, s, k):
        """Record rewriting rule ``(monomial of s, k)`` under index ``s[1]``."""
        self.Rules[s[1]].append((s[0], k))

    def is_rewritable(self, u, k):
        """Return True if ``u * L[k]`` is rewritten by a different entry."""
        return self.find_rewriting(u, k) != k

    def find_rewriting(self, u, k):
        """
        Return the entry that rewrites ``u * L[k]``: the most recently
        added rule of the same index whose monomial divides
        ``u * monomial(sig(k))``, or ``k`` itself if no rule applies.
        """
        Rules = self.Rules
        mk, v = self.sig(k)
        for mj, j in reversed(Rules[v]):
            if mj.divides(u * mk):
                return j
        return k

    def is_top_reducible(self, t, l):
        """Return True if some leading term of the entries ``l`` divides ``t``."""
        poly = self.poly
        return any(poly(g).lt().divides(t) for g in l)

from UserList import UserList

class Signature(UserList):
    """
    Signature ``[monomial, index]`` stored as a two-element list.

    Signatures are ordered by index first, then by monomial (using the
    monomials' own ``<``). Left-multiplying by ``u`` scales the monomial part:
    ``u * Signature(m, i) == Signature(u * m, i)``.
    """

    def __init__(self, monomial, index):
        UserList.__init__(self, [monomial, index])

    def __lt__(self, other):
        if self[1] < other[1]:
            return True
        elif self[1] > other[1]:
            return False
        else:
            return (self[0] < other[0])

    # The remaining orderings are defined explicitly: otherwise UserList's
    # list ordering (monomial first) would be used, contradicting __lt__.
    def __gt__(self, other):
        return Signature.__lt__(other, self)

    def __le__(self, other):
        return Signature.__lt__(self, other) or Signature.__eq__(self, other)

    def __ge__(self, other):
        return Signature.__lt__(other, self) or Signature.__eq__(self, other)

    def __eq__(self, other):
        return self[0] == other[0] and self[1] == other[1]

    def __ne__(self, other):
        return self[0] != other[0] or self[1] != other[1]

    # Backward-compatible alias for the old, misspelled method name.
    __neq__ = __ne__

    def __rmul__(self, other):
        return Signature(other * self[0], self[1])
