(**********************************************)
(*                                            *)
(*   THE LIBKIN-WONG ALGEBRA                  *)
(*                                            *)
(**********************************************)

(* Signature of the complex-object algebra.  Every operator works on
   values of type Common.co; the Common structure is shared with the
   one reachable through Make.Type. *)
signature algebrasig = 
  sig
    structure Common : commonsig
    structure Make : makesig
    sharing Common = Make.Type.Common
    (* Exceptions signalling that an operator received an argument of
       the wrong shape. *)
    exception Badtypealpha
    exception Badtypearith
    exception Badtypebang
    exception Badtypecomp
    exception Badtypecond
    exception Badtypeempty
    exception Badtypeflat
    exception Badtypeid
    exception Badtypeleq
    exception Badtypemap
    exception Badtypenormalize
    exception Badtypeorempty
    exception Badtypeorflat
    exception Badtypeormap
    exception Badtypeorrho
    exception Badtypeorunion
    exception Badtypepair
    exception Badtypeproj
    exception Badtyperho
    exception Badtypesng
    exception Badtypeunion
    (* Raised by the apply family when an argument is not a Base value. *)
    exception Cannotapply
    val alpha : Common.co -> Common.co
    (* Lifting of functions on base values to complex objects. *)
    val apply : (Common.base list -> Common.base)
                -> Common.co -> Common.co
    val apply_binary : (Common.base * Common.base -> Common.base)
                       -> Common.co -> Common.co
    val apply_op2 : (Common.base * Common.base -> Common.base)
                    -> Common.co * Common.co -> Common.co
    val apply_test : (Common.base -> bool) -> Common.co -> Common.co
    val apply_unary : (Common.base -> Common.base)
                      -> Common.co -> Common.co
    val bang : Common.co -> Common.co
    val comp : ('a -> 'b) * (Common.co -> 'a) -> Common.co -> 'b
    val cond : Common.co * ('a -> 'b) * ('a -> 'b) -> 'a -> 'b
    val empty : Common.co
    val emptyorset : Common.co -> Common.co
    val emptyset : Common.co -> Common.co
    val flat : Common.co -> Common.co
    val gen : Common.co -> Common.co
    val id : Common.co -> Common.co
    val leqbase : Common.co * Common.co -> Common.co
    val leqstring : Common.co * Common.co -> Common.co
    val monus : Common.co * Common.co -> Common.co
    val mult : Common.co * Common.co -> Common.co
    val neg : Common.co -> Common.co
    (* Normalization entry points. *)
    val normal : Common.co -> Common.co
    val normal_form : Common.co -> Common.co
    val normal_entries : Common.co -> Common.co
    (* norm (cond, init, update, out) obj: generic fold over the
       normal-form entries of obj. *)
    val norm : (Common.co -> Common.co) * 'a * (Common.co * 'a -> 'a) 
                * ((Common.co * bool) * 'a -> 'b) -> Common.co -> 'b
    val exists : (Common.co -> Common.co) -> Common.co -> Common.co * bool
    (* normal_time (obj, f, compare, tm): the real component is the
       time limit compared against the elapsed timer value. *)
    val normal_time : Common.co * (Common.co -> 'a) * 
                     ('a * 'a -> bool) * real -> Common.co * bool
    val orempty : Common.co
    val orflat : Common.co -> Common.co
    val orpairwith : Common.co * Common.co -> Common.co
    val orsmap : (Common.co -> Common.co) -> Common.co -> Common.co
    val orsng : Common.co -> Common.co
    val orsum : (Common.co -> Common.co) -> Common.co -> Common.co
    val orunion : Common.co * Common.co -> Common.co
    val p1 : Common.co -> Common.co
    val p2 : Common.co -> Common.co
    val pair : (Common.co -> Common.co) * (Common.co -> Common.co)
               -> Common.co -> Common.co
    val pairwith : Common.co * Common.co -> Common.co
    val plus : Common.co * Common.co -> Common.co
    val smap : (Common.co -> Common.co) -> Common.co -> Common.co
    val sng : Common.co -> Common.co
    val sum : (Common.co -> Common.co) -> Common.co -> Common.co
    val union : Common.co * Common.co -> Common.co
  end


functor ALGEBRA (Make : makesig) : algebrasig = 

struct 

structure Make = Make 
structure Common = Make.Type.Common

local open Make Make.Type Make.Dupelim Common in 

(* Type-error exceptions of the algebra, one declaration per line in
   the same alphabetical order as the signature. *)
exception Badtypealpha
exception Badtypearith
exception Badtypebang
exception Badtypecomp
exception Badtypecond
exception Badtypeempty
exception Badtypeflat
exception Badtypeid
exception Badtypeleq
exception Badtypemap
exception Badtypenormalize
exception Badtypeorempty
exception Badtypeorflat
exception Badtypeormap
exception Badtypeorrho
exception Badtypeorunion
exception Badtypepair
exception Badtypeproj
exception Badtyperho
exception Badtypesng
exception Badtypeunion

(* constants (empty set, empty or-set) and identity *)

(* The set with no members. *)
val empty = Set []

(* The or-set with no members. *)
val orempty = Orset []

(* id obj: returns its argument unchanged. *)
fun id (obj : Common.co) = obj

(* composition *)

(* (outer comp inner) arg: applies inner to arg, then outer to the result. *)
infix comp
fun (outer comp inner) (arg : Common.co) =
    let val middle = inner arg
    in outer middle end

(* first and second projections *)
(* p1 x: first component of a Prod; raises Badtypeproj for any other
   object. *)
fun p1 (x : Common.co) =
    case x of
        Prod (first, _) => first
      | _ => raise Badtypeproj

(* p2 x: second component of a Prod; raises Badtypeproj for any other
   object. *)
fun p2 (x : Common.co) =
    case x of
        Prod (_, second) => second
      | _ => raise Badtypeproj

(* pairing of two functions *)
(* (f pair g) x: applies f and then g to the same argument and packs
   the two results into a Prod. *)
infix pair
fun (f pair g) (x : Common.co) =
    let val left = f x
        val right = g x
    in Prod (left, right) end

(* bang *)
(* bang: maps every object to Unit; the argument is ignored. *)
fun bang (_ : Common.co) = Unit

(* Set functions *)

(* singleton for sets and orsets *)
(* sng member: the one-element set containing member. *)
fun sng (member : Common.co) = Set (member :: [])

(* orsng member: the one-element or-set containing member. *)
fun orsng (member : Common.co) = Orset (member :: [])

(* set  and or-set union *)

(* union (Set a, Set b): set union with duplicates removed by remdupl.
   Raises Badtypeunion when either argument is not a Set or when the
   types of the two sets do not unify.

   The unifytypes call is made inside the scope of the handler.  The
   earlier form attached the handler only to the body of a let, so a
   Dontunify raised by the let-bound unifytypes call escaped to the
   caller instead of being reported as Badtypeunion. *)
fun union (Set a, Set b) =
      ((unifytypes (typeof (Set a), typeof (Set b));
        Set (remdupl (a @ b)))
       handle Dontunify => raise Badtypeunion)
  | union _ = raise Badtypeunion

(* orunion (Orset a, Orset b): or-set union with duplicates removed by
   remdupl.  Raises Badtypeorunion when either argument is not an Orset
   or when the types of the two or-sets do not unify.

   The unifytypes call is made inside the scope of the handler.  The
   earlier form attached the handler only to the body of a let, so a
   Dontunify raised by the let-bound unifytypes call escaped to the
   caller instead of being reported as Badtypeorunion. *)
fun orunion (Orset a, Orset b) =
      ((unifytypes (typeof (Orset a), typeof (Orset b));
        Orset (remdupl (a @ b)))
       handle Dontunify => raise Badtypeorunion)
  | orunion _ = raise Badtypeorunion



(* flatten *)

(* Internal error: mklistofco was given something that is neither a Set
   nor an Orset. *)
exception Igoofed

(* mklistofco obj: the member list of a Set or an Orset. *)
fun mklistofco obj =
    case obj of
        Set members => members
      | Orset members => members
      | _ => raise Igoofed

(* orflat x: flattens an or-set of or-sets into one or-set, removing
   duplicates with remdupl.  Only the first member is inspected to
   decide whether x is an or-set of or-sets; the empty or-set is
   accepted.
   NOTE(review): failures raise Badtypeorunion, not Badtypeorflat.
   Kept as is because callers may rely on it; confirm intent. *)
fun orflat x =
    let fun headisorset [] = true
          | headisorset (Orset _ :: _) = true
          | headisorset _ = false
    in case x of
           Orset members =>
             if headisorset members
             then Orset (remdupl (flat_list (map mklistofco members)))
             else raise Badtypeorunion
         | _ => raise Badtypeorunion
    end

(* flat x: flattens a set of sets into one set, removing duplicates
   with remdupl.  Only the first member is inspected to decide whether
   x is a set of sets; the empty set is accepted.
   NOTE(review): failures raise Badtypeunion, not Badtypeflat.  Kept as
   is because callers may rely on it; confirm intent. *)
fun flat x =
    let fun headisset [] = true
          | headisset (Set _ :: _) = true
          | headisset _ = false
    in case x of
           Set members =>
             if headisset members
             then Set (remdupl (flat_list (map mklistofco members)))
             else raise Badtypeunion
         | _ => raise Badtypeunion
    end

(* implementing constants *)

(* emptyset: maps Unit to the empty set; any other argument raises
   Badtypeempty. *)
fun emptyset arg =
    case arg of
        Unit => Set []
      | _ => raise Badtypeempty

(* emptyorset: maps Unit to the empty or-set; any other argument raises
   Badtypeempty. *)
fun emptyorset arg =
    case arg of
        Unit => Orset []
      | _ => raise Badtypeempty


(* rhotwo function *)

(* pairwith (x, Set b): the set of pairs Prod (x, z) for every member z
   of b.  Raises Badtyperho when the second argument is not a Set. *)
fun pairwith (x, y) =
    case y of
        Set members =>
          if isnil members then Set []
          else Set (map (fn member => Prod (x, member)) members)
      | _ => raise Badtyperho

(* orpairwith (x, Orset b): the or-set of pairs Prod (x, z) for every
   member z of b.  Raises Badtyperho when the second argument is not an
   Orset. *)
fun orpairwith (x, y) =
    case y of
        Orset members =>
          if isnil members then Orset []
          else Orset (map (fn member => Prod (x, member)) members)
      | _ => raise Badtyperho

(* mapping *)

(* smap f x: applies f to every member of the set x and removes
   duplicates from the images with remdupl.  Raises Badtypemap when x
   is not a Set. *)
fun smap f x =
    case x of
        Set members =>
          if isnil members then Set []
          else Set (remdupl (map f members))
      | _ => raise Badtypemap

(* orsmap f x: applies f to every member of the or-set x and removes
   duplicates from the images with remdupl.  Raises Badtypemap when x
   is not an Orset. *)
fun orsmap f x =
    case x of
        Orset members =>
          if isnil members then Orset []
          else Orset (remdupl (map f members))
      | _ => raise Badtypemap

(* finally, alpha *) 
(* Internal error raised by alpha's helper mklistoflists when it meets a
   member that is not an Orset (only the first member is checked
   beforehand). *)
exception Iscrewedup


(* alpha x: turns a Set of Orsets into an Orset of Sets.  Each resulting
   Set is one combination that takes one element from every member
   or-set; duplicates are removed inside each combination and among the
   combinations with remdupl.
   Raises Badtypealpha when x is not a Set or when the first member of a
   non-empty Set is not an Orset.
   For the empty Set, and whenever some member or-set is empty, the
   result is the empty Orset. *)
fun alpha(x) = let fun goodforalpha(Set y) = 
                   (* listalpha: all combinations for a list of element
                      lists.  One list yields its singletons; otherwise
                      the head list is combined with the combinations
                      of the remaining lists by mapadd. *)
                   let fun reduced_list_alpha(x) = let fun listalpha [] = [] |
                                    listalpha (a::nil) = map (fn z => [z]) a |
                                    listalpha (a::b::x) = 
                                     (* mapadd (elems, combos): every
                                        combo extended at its end by
                                        every elem. *)
                                     let fun mapadd([],_) = [] |
                                             mapadd(_,[]) = [] |
                                             mapadd(a::x,b::y) = mapadd(x,y) @
                                             ((map (fn z => z@[a]) y)) @
                                             ((map (fn z => b@[z]) (a::x))) in 
                                if isnil(a) then [] 
                                else mapadd(a, listalpha(b::x)) end 
                            in (map remdupl (listalpha(x))) end 
                       fun isorset(Orset z) = true |
                           isorset _        = false
                       (* only the first member is inspected *)
                       fun islistoforsets [] = true |
                           islistoforsets (a::z) = isorset(a)
                       (* strips the Orset constructor from every member *)
                       fun mklistoflists [] = [] |
                           mklistoflists((Orset z)::w) =
                                     z::mklistoflists(w) |
                           mklistoflists _ = raise Iscrewedup
                    in if islistoforsets(y) then 
                          (Orset (remdupl
                          (map (fn z => (Set z))
                          (reduced_list_alpha(mklistoflists(y))))))
                       else raise Badtypealpha end |
                       goodforalpha _ = raise Badtypealpha 
               in goodforalpha(x) end 
                 
                       
(* booleans *)
(* neg: boolean negation on Bool objects; any other object raises
   Badtypebool.
   NOTE(review): Badtypebool is not declared in this functor; presumably
   it comes from one of the opened structures -- confirm. *)
fun neg obj =
    case obj of
        Bool b => Bool (not b)
      | _ => raise Badtypebool

(* cond (condition, first, second) x: applies first to x when condition
   is Bool true and second when it is Bool false.  Raises Badtypecond
   when condition is not a Bool; the check happens only once x is
   supplied. *)
fun cond (condition, first, second) x =
    case condition of
        Bool flag => if flag then first x else second x
      | _ => raise Badtypecond
 

(* functions for naturals: plus, mult, monus, gen, and the big sums sum and orsum *)

(* plus (Num i, Num j): their sum as a Num.  Raises Badtypearith unless
   both arguments are Num. *)
fun plus (x, y) =
    case (x, y) of
        (Num i, Num j) => Num (i + j)
      | _ => raise Badtypearith

(* mult (Num i, Num j): their product as a Num.  Raises Badtypearith
   unless both arguments are Num. *)
fun mult (x, y) =
    case (x, y) of
        (Num i, Num j) => Num (i * j)
      | _ => raise Badtypearith

(* monus (Num i, Num j): truncated subtraction, i - j when i is at least
   j and 0 otherwise.  Raises Badtypearith unless both arguments are
   Num. *)
fun monus (x, y) =
    case (x, y) of
        (Num i, Num j) => if i >= j then Num (i - j) else Num 0
      | _ => raise Badtypearith

(* gen (Num n): passes the integer list [n, n-1, ..., 0] to mksetint
   (presumably yielding the set of naturals 0..n -- mksetint is defined
   elsewhere).
   Raises Badtypearith when the argument is not a Num, and also when n
   is negative: the countdown only stops at 0, so a negative n used to
   recurse without terminating. *)
fun gen x =
    let fun countdown k = if k = 0 then [0] else k :: countdown (k - 1)
    in case x of
           Num n => if n < 0 then raise Badtypearith
                    else mksetint (countdown n)
         | _ => raise Badtypearith
    end

(* sum f x: adds up, with plus, the values f gives on the members of
   the set x; the empty set sums to Num 0.  Raises Badtypearith when x
   is not a Set (or, through plus, when f yields a non-Num). *)
fun sum f x =
    let fun total members =
            case members of
                [] => Num 0
              | first :: rest => plus (f first, total rest)
    in case x of
           Set members => total members
         | _ => raise Badtypearith
    end

(* orsum f x: adds up, with plus, the values f gives on the members of
   the or-set x; the empty or-set sums to Num 0.  Raises Badtypearith
   when x is not an Orset (or, through plus, when f yields a
   non-Num). *)
fun orsum f x =
    let fun total members =
            case members of
                [] => Num 0
              | first :: rest => plus (f first, total rest)
    in case x of
           Orset members => total members
         | _ => raise Badtypearith
    end

(* more derived functions: rhoone and cartprod *)

(* rhoone (x, y): for a set x, the set of pairs Prod (member, y).  Built
   by pairing y with every member and then swapping each pair. *)
fun rhoone (x, y) =
    let fun swap z = Prod (p2 z, p1 z)
    in smap swap (pairwith (y, x)) end
(* fun orrhoone(x,y) = orsmap (p2 pair p1) (orpairwith(y,x))  *)

(* orrhoone (Orset z, y): the or-set of pairs Prod (v, y) for every
   member v of z.  No duplicate removal is performed.  Raises Badtyperho
   when the first argument is not an Orset. *)
fun orrhoone (x, y) =
    case x of
        Orset members => Orset (map (fn v => Prod (v, y)) members)
      | _ => raise Badtyperho


(* cartprod (x, y): cartesian product of two sets.  x is paired with
   every member of y, each such pair is expanded with rhoone, and the
   resulting set of sets is flattened. *)
fun cartprod (x, y) =
    let val tagged = pairwith (x, y)
        val nested = smap (fn z => rhoone (p1 z, p2 z)) tagged
    in flat nested end

(* orcartprod (x, y): cartesian product of two or-sets.  x is paired
   with every member of y, each such pair is expanded with orrhoone,
   and the resulting or-set of or-sets is flattened. *)
fun orcartprod (x, y) =
    let val tagged = orpairwith (x, y)
        val nested = orsmap (fn z => orrhoone (p1 z, p2 z)) tagged
    in orflat nested end


(* normalization *)
(* strategy: find the deepest instance of application of a rule *)
(* apply the rule and go up *)

(* if canapply(typeof(x)) then norm(x) else x *)

(* datatype of annotated complex objects *)

(* Status flag stored in every structured annotation node.  pick yields
   VOID for a node whose flag is STOP; reset sets the flag back to GO. *)
datatype LOCAL_COND = STOP | GO;

(* Annotated complex objects.  Leaf constructors carry the value of the
   corresponding atomic object; PNode, SNode and ONode annotate
   products, sets and or-sets with a mutable status flag.  Every
   alternative of an ONode carries a mutable cn flag; pick follows the
   first alternative whose cn flag is set.  DONE is what pick's search
   returns when no alternative is flagged. *)
datatype ACO = DONE |
               UNode of unit |
               STRNode of string |
               NNode of int |
               BLNode of bool |
               BNode of base |
               PNode of LOCAL_COND ref * (ACO * ACO) |
               SNode of LOCAL_COND ref *  ACO list   |
               ONode of LOCAL_COND ref * ({vl : ACO, cn : bool ref}) list;

(* initial annotation *)

(* initial obj: builds the starting annotation of obj.  Every structured
   node starts with status GO; in every or-set the first alternative is
   flagged as chosen and all the others are not. *)
fun initial (Base b) = BNode b
  | initial (Num n) = NNode n
  | initial Unit = UNode ()
  | initial (Bool b) = BLNode b
  | initial (Str s) = STRNode s
  | initial (Prod (l, r)) = PNode (ref GO, (initial l, initial r))
  | initial (Set elems) = SNode (ref GO, map initial elems)
  | initial (Orset alts) =
      let fun alternative chosen member =
              {vl = initial member, cn = ref chosen}
          val annotated =
              case alts of
                  [] => []
                | first :: rest =>
                    alternative true first :: map (alternative false) rest
      in ONode (ref GO, annotated) end;

(* given an annotation, pick a normal form entry *)

(* Marker object meaning "no entry here". *)
val VOID = Str "void";

(* isvoid obj: true exactly for the string object "void". *)
fun isvoid obj =
    case obj of
        Str s => s = "void"
      | _ => false;

(* pick a: reads off the object currently selected by annotation a.
   Leaves give back their value.  A structured node whose status is
   STOP gives VOID; a product or set gives VOID as soon as one of its
   components does.  An or-set node gives the pick of its first
   alternative flagged as chosen, or VOID when none is flagged. *)
fun pick DONE = VOID
  | pick (UNode _) = Unit
  | pick (NNode i) = Num i
  | pick (STRNode s) = Str s
  | pick (BLNode b) = Bool b
  | pick (BNode n) = Base n
  | pick (PNode (cnd, (l, r))) =
      if !cnd = STOP then VOID
      else let val pl = pick l
               val pr = pick r
           in if isvoid pl orelse isvoid pr then VOID
              else Prod (pl, pr)
           end
  | pick (SNode (cnd, elems)) =
      if !cnd = STOP then VOID
      else let val picked = map pick elems
               fun anyvoid [] = false
                 | anyvoid (e :: rest) = isvoid e orelse anyvoid rest
           in if anyvoid picked then VOID else Set picked end
  | pick (ONode (cnd, alts)) =
      if !cnd = STOP then VOID
      else let fun chosen [] = DONE
                 | chosen ({vl = v, cn = c} :: rest) =
                     if !c then v else chosen rest
           in pick (chosen alts) end;


(* now we need new resetting *)


(* reset a: puts annotation a back into its starting state.  Every
   structured node gets status GO, and in every or-set the first
   alternative is flagged as chosen and the others are unflagged, with
   all alternatives reset recursively.  The result is always true. *)
fun reset (PNode (flag, (left, right))) =
      (flag := GO; reset left andalso reset right)
  | reset (SNode (flag, elems)) =
      (flag := GO; map reset elems; true)
  | reset (ONode (flag, [])) = (flag := GO; true)
  | reset (ONode (flag, {vl = first, cn = chosen} :: others)) =
      (flag := GO;
       chosen := true;
       reset first;
       map (fn {vl = v, cn = c} => (c := false; reset v)) others;
       true)
  | reset _ = true;



(* now, status of object *)

(* status a: the current status flag of a structured node; every other
   node counts as STOP. *)
fun status node =
    case node of
        PNode (c, _) => !c
      | SNode (c, _) => !c
      | ONode (c, _) => !c
      | _ => STOP;

(* changefirst alts: flags the first alternative of the list as chosen;
   does nothing for the empty list. *)
fun changefirst alts =
    case alts of
        [] => ()
      | {vl = _, cn = flag} :: _ => flag := true;

(* next a: moves annotation a, in place, to its next selection.  Returns
   true when a new selection was reached and false otherwise; on the
   false paths for structured nodes the status flag is set to STOP,
   except for a set that was already STOP, which is left untouched.
   Leaves and DONE always give false. *)
fun next (PNode (c,(x,y))) = 
       (* advance the second component first; when it is exhausted,
          advance the first one and restart the second *)
       if next(y) then (c := GO; true) else
        if next(x) then (c := GO; reset y; true) 
         else (c := STOP; false) |
    next (ONode (c,[])) = (c := STOP; false) | 
    next (ONode (c, X as ({vl=v,cn=cnd}::x))) = 
        (* a stopped or-set is reset and reported as advanced *)
        if !c = STOP then (reset(ONode(c,X)); true) else
          (* skip ahead to the alternative currently flagged; the
             recursive call shares the same status ref c *)
          if not(!cnd) then next (ONode (c,x)) else
            if next(v) then true else
               (* the flagged alternative is exhausted: move the flag to
                  the following alternative, or stop if it was the last *)
               if isnil(x) then (c := STOP; false) else
                  (cnd := false; changefirst x; true) |
    next (SNode (c,[])) = (c := STOP; false) |
    next (SNode (c,(a::x))) = 
         if !c = STOP then false else
           (* advance the first member that can still advance, resetting
              the members before it *)
           if next(a) then true else 
             (reset a; next (SNode (c,x))) |
    next _ = false;




     
(* now that we have next, we do norm *)

(* istrue obj: the boolean carried by a Bool object; false for every
   other object. *)
fun istrue obj =
    case obj of
        Bool b => b
      | _ => false;

(* norm (cond, init, update, out) obj: folds over the selections of the
   annotation of obj.  Each round feeds the current selection into the
   accumulator with update and then advances with next.  The loop ends
   when next reports exhaustion or when cond holds on the selection
   reached after advancing; out then receives that selection, the
   exhaustion flag and the accumulator.  If the annotation has status
   STOP from the start, out is called at once with the initial
   accumulator.
   NOTE(review): cond is only evaluated on selections reached after an
   advance, so it is never evaluated on the first selection -- confirm
   that this is intended. *)
fun norm (cond, init, update, out) obj =
    let val annotated = initial obj
        fun loop acc =
            let val acc' = update (pick annotated, acc)
                val exhausted = not (next annotated)
            in if exhausted orelse istrue (cond (pick annotated))
               then out ((pick annotated, exhausted), acc')
               else loop acc'
            end
    in case status annotated of
           STOP => out ((pick annotated, true), init)
         | GO => loop init
    end;

(* norm_nocond (init, update, out) obj: like norm but with no early-exit
   test; the loop runs until next reports exhaustion.  Not exported by
   the signature. *)
fun norm_nocond (init, update, out) obj =
    let val annotated = initial obj
        fun loop acc =
            let val acc' = update (pick annotated, acc)
                val exhausted = not (next annotated)
            in if exhausted
               then out ((pick annotated, exhausted), acc')
               else loop acc'
            end
    in case status annotated of
           STOP => out ((pick annotated, true), init)
         | GO => loop init
    end;

(* norm_simple (init, update, out) obj: like norm_nocond, except that
   update sees only the accumulator and not the current selection.  Not
   exported by the signature. *)
fun norm_simple (init, update, out) obj =
    let val annotated = initial obj
        fun loop acc =
            let val acc' = update acc
                val exhausted = not (next annotated)
            in if exhausted
               then out ((pick annotated, exhausted), acc')
               else loop acc'
            end
    in case status annotated of
           STOP => out ((pick annotated, true), init)
         | GO => loop init
    end;



(* exists p obj: runs norm with p as the exit test and a dummy boolean
   accumulator.  The result is the pair norm hands to out: the selection
   at which the loop stopped and the exhaustion flag. *)
fun exists p =
    norm (p, false, (fn _ => false), (fn (stopped, _) => stopped));

(* the usual normalization *)

(* normal_entries obj: when canapply holds for the type of obj, collects
   every selection visited by norm into an Orset (the latest one first);
   otherwise obj is returned unchanged. *)
fun normal_entries obj =
    let fun never _ = mkboolco false
        fun collect (entry, sofar) =
            let val (Orset seen) = sofar
            in Orset (entry :: seen) end
        fun result (_, collected) = collected
    in if canapply (typeof obj)
       then norm (never, Orset [], collect, result) obj
       else obj
    end;

(* isvoidorset obj: true for the empty Orset and for an Orset whose
   first member is the void marker; false for everything else. *)
fun isvoidorset obj =
    case obj of
        Orset [] => true
      | Orset (first :: _) => isvoid first
      | _ => false;

(* normal_form x: when canapply holds for the type of x, the
   duplicate-free (full_remdupl) or-set of its normal-form entries, with
   a void result replaced by the empty Orset; otherwise x itself. *)
fun normal_form x =
    if not (canapply (typeof x)) then x
    else let val entries = full_remdupl (normal_entries x)
         in if isvoidorset entries then Orset [] else entries end;


(* now, normalization with timer *)

(* fun normal_time (obj, f,  tm) =
   let open System.Timer  
       val t = start_timer() 
       val totaltime = ref 0.0  
       fun getrealtime x =  
            let val TIME {sec = Z, usec = z} = (check_timer x) 
            in (real Z) + (real z)/(real 1000000) end 
       fun update (nfe,(bestentry,bestvalue)) =  
           let val newvalue = f nfe  
           in if newvalue < bestvalue then (bestentry,bestvalue) 
              else (nfe,newvalue) 
           end 
in 
  norm ( (fn _ => mkboolco (getrealtime(t) > tm)),  
         ((Num 1),0.0),                       
	 update,                            
	 (fn ((_,bv),acc) => let val (accobj,_) = acc in (accobj,bv) end)) 
  obj 
end *) 

(* normal_time (obj, f, compare, tm): walks the selections of obj with
   norm, scoring each with f and remembering one entry.  A new entry
   replaces the remembered one unless compare (newscore, bestscore)
   holds.  The walk also stops once the timer reading exceeds tm.  The
   result is the remembered entry paired with norm's exhaustion flag. *)
local open System.Timer in
fun normal_time (obj, f, compare, tm) =
    let val timer = start_timer ()
        (* timer reading in seconds, as a real *)
        fun elapsed tmr =
            let val TIME {sec = whole, usec = micro} = check_timer tmr
            in real whole + real micro / 1000000.0 end
        fun keepbest (candidate, (best, bestscore)) =
            let val score = f candidate
            in if compare (score, bestscore) then (best, bestscore)
               else (candidate, score)
            end
        (* the first selection seeds the accumulator *)
        val firstentry = pick (initial obj)
        val firstscore = f firstentry
        fun timeup _ = mkboolco (elapsed timer > tm)
        fun finish ((_, exhausted), (best, _)) = (best, exhausted)
    in norm (timeup, (firstentry, firstscore), keepbest, finish) obj end
end


(* normal comes from version 2.3 *)

local

(* normal_aux x: structural normalization.  Atomic objects are returned
   as they are.  A product is combined with orcartprod, orrhoone or
   orpairwith depending on which components satisfy canapply on their
   type.  A non-empty set whose first member satisfies canapply is
   normalized member-wise and passed through alpha; a non-empty or-set
   whose first member does is normalized member-wise and flattened with
   orflat.  Nothing is done when canapply fails on the type of x. *)
fun normal_aux x =
    let (* only the first member decides whether to go deeper *)
        fun headcanapply [] = false
          | headcanapply (a :: _) = canapply (typeof a)
        fun nf (Num i) = Num i
          | nf (Base b) = Base b
          | nf Unit = Unit
          | nf (Bool b) = Bool b
          | nf (Str s) = Str s
          | nf (Prod (l, r)) =
              if canapply (typeof l) andalso canapply (typeof r)
              then orcartprod (nf l, nf r)
              else if canapply (typeof l) then orrhoone (nf l, r)
              else if canapply (typeof r) then orpairwith (l, nf r)
              else Prod (l, r)
          | nf (Set y) =
              if isnil y then Set y
              else if headcanapply y then alpha (smap nf (Set y))
              else Set y
          | nf (Orset y) =
              if isnil y then Orset y
              else if headcanapply y then orflat (orsmap nf (Orset y))
              else Orset y
    in if canapply (typeof x) then nf x else x end

in

(* normal x: normalizes x with the Dup_elim flag switched off, then
   restores the flag and applies full_remdupl to the result.
   The flag is restored even when normalization raises: the saved value
   is written back before the exception is re-raised, so a failure no
   longer leaves Dup_elim set to false. *)
fun normal x =
    let val saved = !Dup_elim
        val _ = Dup_elim := false
        val nx = normal_aux x
                 handle e => (Dup_elim := saved; raise e)
    in (Dup_elim := saved; full_remdupl nx) end

end



(* application of primitives on base type *)

(* Raised by the apply family when an argument is not a Base value. *)
exception Cannotapply;

(* apply_unary f: lifts a unary function on base values to Base objects;
   any other object raises Cannotapply. *)
fun apply_unary f =
    fn obj =>
       case obj of
           Base b => Base (f b)
         | _ => raise Cannotapply;

(* applications of binary primitives *)

(* apply_binary f: lifts a binary function on base values to a Prod of
   two Base objects; any other argument raises Cannotapply. *)
fun apply_binary f =
    fn obj =>
       case obj of
           Prod (Base b1, Base b2) => Base (f (b1, b2))
         | _ => raise Cannotapply;

(* apply_op2 f: lifts a binary function on base values to an ML pair of
   Base objects; Cannotapply is raised when either one is not a Base. *)
fun apply_op2 f =
    fn operands =>
       case operands of
           (Base b1, Base b2) => Base (f (b1, b2))
         | _ => raise Cannotapply;

(* application of arbitrary primitives *)

(* apply f: lifts a function on lists of base values.  The argument is
   read as a nested product; its leaves, from left to right, must all be
   Base objects and their values are passed to f as a list.  A leaf that
   is not a Base raises Cannotapply. *)
fun apply f =
    let fun leaves (Prod (l, r)) = leaves l @ leaves r
          | leaves (Base b) = [b]
          | leaves _ = raise Cannotapply
    in fn obj => Base (f (leaves obj)) end;

(* application of a test *)

(* apply_test f: lifts a predicate on base values to Base objects, giving
   a Bool; any other object raises Cannotapply. *)
fun apply_test f =
    fn obj =>
       case obj of
           Base b => Bool (f b)
         | _ => raise Cannotapply

(* Order for base types *)
(* strings *)
(* leqstring (Str s1, Str s2): Bool true exactly when s1 is less than or
   equal to s2 in the built-in string order.  Raises Badtypeleq unless
   both arguments are Str.
   The comparison used to be the strict s1 < s2, which made this
   "less-or-equal" order irreflexive (a string was not leq to itself);
   it now uses <=, as the name says. *)
fun leqstring (x, y) =
    case (x, y) of
        (Str s1, Str s2) => Bool (s1 <= s2)
      | _ => raise Badtypeleq

(* leqbase (Base b1, Base b2): the result of BTS.leq_base on the two
   base values, wrapped in a Bool.  Raises Badtypeleq unless both
   arguments are Base. *)
fun leqbase (x, y) =
    case (x, y) of
        (Base b1, Base b2) => Bool (BTS.leq_base (b1, b2))
      | _ => raise Badtypeleq

end 

end (* Functor ALGEBRA *)


(************** end of algebra **************)

