

(*****************************************************)
(*                                                   *)
(*      DUPLICATE ELIMINATION                        *)
(*                                                   *)
(*****************************************************)

(* Interface of the duplicate-elimination module. *)
signature dupelimsig = 
  sig
    (* the structure supplying the complex-object type Common.co *)
    structure Common : commonsig
    (* raised by eq (see functor DUPELIM) when the two objects compared
       are not built with matching constructors *)
    exception Badtypebool
    (* equality test on two complex objects; the answer is itself a
       complex object (DUPELIM wraps the boolean in Bool) *)
    val eq : Common.co * Common.co -> Common.co
    (* duplicate elimination on a list of objects; in DUPELIM it is a
       no-op unless the Dup_elim flag is set *)
    val remdupl : Common.co list -> Common.co list
    (* recursive duplicate elimination through nested products, sets
       and or-sets *)
    val full_remdupl : Common.co -> Common.co
  end


functor DUPELIM(Common: commonsig) : dupelimsig =
struct

structure Common = Common

local open Common in 
(* equality of complex objects *)

(* Raised by eq when the two objects are not built with the same
   constructor (or with a constructor eq does not handle). *)
exception Badtypebool

(* eq (o1, o2): structural equality of two complex objects, returned as
   a Bool object.  Set and Orset members are compared as sets: each list
   has to be contained in the other, so order and repetitions are
   ignored.  Raises Badtypebool on mismatching constructors. *)
fun eq (o1, o2) =
    let
        (* member same (v, ws): is some element of ws `same` as v? *)
        fun member same (_, []) = false
          | member same (v, w :: ws) =
              same (v, w) orelse member same (v, ws)

        (* subset same (vs, ws): does every element of vs occur in ws? *)
        fun subset same ([], _) = true
          | subset same (_ :: _, []) = false
          | subset same (v :: vs, ws) =
              member same (v, ws) andalso subset same (vs, ws)

        (* mutual containment of two member lists *)
        fun same_members (xs, ys) =
              subset equal (xs, ys) andalso subset equal (ys, xs)

        and equal (Unit, Unit) = true
          | equal (Base b1, Base b2) = BTS.eq_base (b1, b2)
          | equal (Bool b1, Bool b2) = (b1 = b2)
          | equal (Num i, Num j) = (i = j)
          | equal (Str s1, Str s2) = (s1 = s2)
          | equal (Prod (x1, x2), Prod (y1, y2)) =
              equal (x1, y1) andalso equal (x2, y2)
          | equal (Set xs, Set ys) = same_members (xs, ys)
          | equal (Orset xs, Orset ys) = same_members (xs, ys)
          | equal _ = raise Badtypebool
    in
        Bool (equal (o1, o2))
    end


(* we need sorting for duplicate elimination *)
(* Raised if the run-merging helper is ever handed an empty list of runs. *)
exception Cannotsort;

(* msort gt ls: merge sort of ls that exploits ascending runs already
   present in the input.  gt (a, b) must hold when a is to be placed
   after b; the result is ordered so that no element is gt its
   successor.  Elements that compare as ties may change relative
   order. *)
fun msort (gt : 'a * 'a -> bool) ls =
    let
        (* merge two sorted lists; on a tie the head of the first wins *)
        fun merge ([], right) = right
          | merge (left, []) = left
          | merge (left as (l :: ls'), right as (r :: rs')) =
              if gt (l, r) then r :: merge (left, rs')
              else l :: merge (ls', right)

        (* merge the two front runs repeatedly, driven by the parity of
           the counter k; a single remaining run is returned as is *)
        fun mergepairs (runs as [_], _) = runs
          | mergepairs (r1 :: r2 :: runs, k) =
              if k mod 2 = 1 then r1 :: r2 :: runs
              else mergepairs (merge (r1, r2) :: runs, k div 2)
          | mergepairs _ = raise Cannotsort

        (* extend the reversed run while the next element is strictly
           greater than the run's latest one; return run and leftover *)
        fun nextrun (run, []) = (rev run, [])
          | nextrun (run, pending as (x :: xs)) =
              if gt (x, hd run) then nextrun (x :: run, xs)
              else (rev run, pending)

        (* cut the input into runs and merge them as they are found *)
        fun samsorting ([], runs, _) = hd (mergepairs (runs, 0))
          | samsorting (x :: xs, runs, k) =
              let
                  val (run, rest) = nextrun ([x], xs)
              in
                  samsorting (rest, mergepairs (run :: runs, k + 1), k + 1)
              end
    in
        case ls of
            [] => []
          | _ => samsorting (ls, [], 0)
    end;

(* now determine the size of an object; this size serves as its hash value *)

(* hash_co obj: integer hash of a complex object.  Atoms hash to a
   small value or to their own size/number; products, sets and or-sets
   hash to the sum of the hashes of their components.
   NOTE(review): a Set/Orset listing a member twice hashes differently
   from the duplicate-free set that eq treats as equal -- presumably
   callers only pass already-normalised members; confirm. *)
fun hash_co obj =
    let
        (* sum of the hashes of all members of a set or or-set *)
        fun sum_members [] = 0
          | sum_members (m :: ms) = hash_co m + sum_members ms
    in
        case obj of
            Base b => BTS.hash_base b
          | Unit => 1
          | Num i => i
          | Bool b => if b then 1 else 2
          | Str s => size s
          | Prod (l, r) => hash_co l + hash_co r
          | Set members => sum_members members
          | Orset members => sum_members members
    end;

(* pairwithsize L: pair every object of L with its hash_co value. *)
fun pairwithsize L =
    let
        fun tag obj = (obj, hash_co obj)
    in
        map tag L
    end;

(* sortbysize: sort (object, size) pairs into ascending order of size. *)
val sortbysize =
    msort (fn ((_, sizeA : int), (_, sizeB)) => sizeA > sizeB);

(* createlofl pairs: split a list of (object, key) pairs into groups of
   neighbouring pairs sharing the same key.  Only adjacent pairs are
   grouped.  The empty list yields one empty group. *)
fun createlofl [] = [[]]
  | createlofl [only] = [[only]]
  | createlofl (first :: (rest as (next :: _))) =
      let
          val (_, firstKey) = first
          val (_, nextKey) = next
      in
          if firstKey = nextKey
          then (case createlofl rest of
                    group :: groups => (first :: group) :: groups
                    (* createlofl never returns an empty list *)
                  | [] => [[first]])
          else [first] :: createlofl rest
      end;

(* isbooltrue obj: the boolean carried by a Bool object; false for any
   other kind of object. *)
fun isbooltrue obj =
    case obj of
        Bool flag => flag
      | _ => false;

(* mem_new (a, xs): does the object of the pair a equal (under eq) the
   object of some pair in xs?  Second components are ignored. *)
fun mem_new (_, []) = false
  | mem_new (probe as (obj, _), (candidate, _) :: rest) =
      isbooltrue (eq (obj, candidate)) orelse mem_new (probe, rest);

(* remdupl_naive pairs: quadratic duplicate elimination on a list of
   (object, size) pairs.  A pair is dropped when an eq-equal object
   appears later in the list, so the last occurrence survives.  The
   result contains the bare objects only. *)
fun remdupl_naive [] = []
  | remdupl_naive ((entry as (obj, _)) :: rest) =
      if mem_new (entry, rest)
      then remdupl_naive rest
      else obj :: remdupl_naive rest;

(* now duplicate elimination *)

(* remdupl_always x: unconditional duplicate elimination.  Objects are
   paired with their size, sorted by it, grouped by equal size, and the
   naive elimination is then run inside each group only. *)
fun remdupl_always x =
    let
        val sized = pairwithsize x
        val ordered = sortbysize sized
        val groups = createlofl ordered
    in
        flat_list (map remdupl_naive groups)
    end;

(* remdupl x: eliminate duplicates only when the Dup_elim flag is set;
   otherwise x is returned untouched. *)
fun remdupl x =
    case !Dup_elim of
        true => remdupl_always x
      | false => x;


(* duplicate elimination by equality *)

(* mem_plain (a, xs): membership test using the built-in equality. *)
fun mem_plain (_, []) = false
  | mem_plain (a, b :: rest) = a = b orelse mem_plain (a, rest)

(* remdupl_plain xs: duplicate elimination using the built-in equality;
   an element is dropped when it occurs again later in the list. *)
fun remdupl_plain [] = []
  | remdupl_plain (a :: rest) =
      (case mem_plain (a, rest) of
           true => remdupl_plain rest
         | false => a :: remdupl_plain rest)

(* recursive duplicate elimination *)
(* it will only be used with norm *)


(* full_remdupl obj: duplicate elimination at every nesting level.
   Atoms are returned unchanged, products are rebuilt from their cleaned
   components, and the members of a set or or-set are cleaned first and
   then passed through remdupl_always (regardless of the Dup_elim
   flag). *)
fun full_remdupl obj =
    let
        (* clean each member, then drop duplicates among the results *)
        fun clean members = remdupl_always (map full_remdupl members)
    in
        case obj of
            Base b => Base b
          | Unit => Unit
          | Num i => Num i
          | Bool b => Bool b
          | Str s => Str s
          | Prod (l, r) => Prod (full_remdupl l, full_remdupl r)
          | Set members => Set (clean members)
          | Orset members => Orset (clean members)
    end;




end

end (* functor DUPELIM *)

(************* end dup-elim ****************)
