(*********************************************************)
(*                                                       *)
(*       WORKING WITH COMPLEX OBJECT TYPES               *)
(*                                                       *)
(*********************************************************)

signature typesig =
  sig
    (* the structure supplying the object (co) and type (co_type) datatypes *)
    structure Common : commonsig
    (* raised by unifytypes when the two types have no common shape *)
    exception Dontunify
    (* true when the type contains an or-set type constructor *)
    val canapply : Common.co_type -> bool
    (* rewrite a type until it either contains no or-set or is a single *)
    (* outermost or-set whose argument contains no or-set *)
    val normalize : Common.co_type -> Common.co_type
    (* print the string produced by tp_print, followed by two newlines *)
    val printtype : Common.co_type -> unit
    (* render a type as a string *)
    val tp_print : Common.co_type -> string
    (* infer the type of a complex object *)
    val typeof : Common.co -> Common.co_type
    (* unify two types; raises Dontunify when they do not unify *)
    val unifytypes : Common.co_type * Common.co_type -> Common.co_type
  end


functor TYPE (Common: commonsig) : typesig = 
struct 

structure Common = Common 


local open Common in

(* canapply ty: does ty contain an or-set type constructor?          *)
(* An or-set answers true at once; sets and products are searched    *)
(* recursively; every other type answers false.                      *)
fun canapply ty =
    case ty of
        Orsett _ => true
      | Sett elem => canapply elem
      | Prodt (left, right) => canapply left orelse canapply right
      | _ => false

(* isorset ty: is the outermost constructor of ty an or-set? *)
fun isorset ty =
    case ty of
        Orsett _ => true
      | _ => false

(* get rid of the outermost or-set *)

(* stripone ty: remove one outermost or-set constructor if there is one; *)
(* any other type is returned unchanged.                                *)
fun stripone ty =
    case ty of
        Orsett inner => inner
      | other => other
             
(* print types nicely *)

(* tp_print t: render the type t as a string.                            *)
(*   unit/bool/int/string are printed by name, the base type as          *)
(*   BTS.name_base, products as  t1*t2,  sets as  {t},  or-sets as  <t>. *)
(*   A type variable is printed as a quote followed by a letter chosen   *)
(*   by the rank of its label among the DISTINCT labels occurring in t   *)
(*   (smallest label -> 'a, next -> 'b, ...).  Labels are deduplicated   *)
(*   before ranking so that a variable occurring several times does not  *)
(*   push the letters forward (previously Prodt(Vart 1,Vart 1) came out  *)
(*   as 'b*'b instead of 'a*'a).                                         *)
(*   NOTE: ranks above 26 run past the letter z, since the letter is     *)
(*   computed as chr(rank+96).                                           *)
fun tp_print t = let 
   (* all variable labels of a type, left to right, repeats included *)
   fun collect_labels (Vart i) = [i] |
       collect_labels (Prodt(t1,t2)) = collect_labels(t1)@collect_labels(t2) |
       collect_labels (Sett t) = collect_labels(t) |
       collect_labels (Orsett t) = collect_labels(t) |
       collect_labels _ = []
   (* drop repeated labels, keeping the first occurrence of each *)
   fun distinct ([]:int list) = [] |
       distinct (a::x) =
         let fun without [] = [] |
                 without (b::y) = if b = a then without y else b::without y
         in a::distinct(without x) end
   (* pair every label z with the number of labels in the list that are <= z; *)
   (* on a duplicate-free list this is the 1-based rank of z                  *)
   fun rass x = let fun lessthanz ([]:int list,z) = 0 |
                        lessthanz ((a::x),z) = if a <= z then 1+lessthanz(x,z)
                                        else lessthanz(x,z)
   in map (fn z => {fst = z, snd = lessthanz(x,z)}) x end
   (* letter for label n; chr(96) is the fallback for a label not in the table *)
   fun getindex ([],n) = chr(96) |
       getindex (({fst = a1, snd = a2}::x),n) = 
                if a1 = n then chr(a2+96) else getindex(x,n)
   val lstoflbl = rass(distinct(collect_labels(t))) 
   fun printtype_internal Unitt = "unit"   |
    printtype_internal Baset = BTS.name_base  |
    printtype_internal (Vart i) = "'"^getindex(lstoflbl,i) |
    printtype_internal Boolt = "bool"   |
    printtype_internal Numt = "int"    |
    printtype_internal Strt  = "string" |
    printtype_internal (Prodt(t1,t2)) = printtype_internal(t1)^
                                    "*"^printtype_internal(t2) |
    printtype_internal (Sett t) = "{"^printtype_internal(t)^"}" |
    printtype_internal (Orsett t) = "<"^printtype_internal(t)^">"
in printtype_internal(t) end;

(* printtype t: write the textual form of t, then two newlines *)
fun printtype t =
    let val rendered = tp_print t
    in print rendered; print "\n\n" end;

(* do one step of rewriting, whatever found first *)

(* rewrite ty: perform one rewriting step that moves an or-set outwards. *)
(*   product : if a component contains an or-set, that component is       *)
(*             rewritten, its outer or-set stripped, and the whole        *)
(*             product is wrapped in an or-set;                           *)
(*   set     : likewise for the element type;                             *)
(*   or-set  : a directly nested or-set is collapsed, otherwise the       *)
(*             argument is rewritten;                                     *)
(*   anything else, or a product/set without or-sets, is left unchanged.  *)
fun rewrite ty =
    case ty of
        Prodt (left, right) =>
          (case (canapply left, canapply right) of
               (true, true) =>
                 Orsett (Prodt (stripone (rewrite left),
                                stripone (rewrite right)))
             | (true, false) =>
                 Orsett (Prodt (stripone (rewrite left), right))
             | (false, true) =>
                 Orsett (Prodt (left, stripone (rewrite right)))
             | (false, false) => ty)
      | Sett elem =>
          if canapply elem
          then Orsett (Sett (stripone (rewrite elem)))
          else ty
      | Orsett (Orsett inner) => Orsett inner
      | Orsett inner => Orsett (rewrite inner)
      | other => other

(* is it one orset, that is, normal form *)

(* oneorset ty: ty is an or-set whose argument contains no further or-set *)
fun oneorset ty =
    case ty of
        Orsett inner => not (canapply inner)
      | _ => false

(* normalization *)

(* normalize ty: keep applying rewrite while ty still contains an or-set *)
(* and is not yet a single outermost or-set; otherwise return ty as is.  *)
fun normalize ty =
    if canapply ty andalso not (oneorset ty)
    then normalize (rewrite ty)
    else ty

(* now define function typeof which returns type *)
(* we need unification first. when we try to infer the type of *)
(* a set, we infer type of each element and then simply unify all *)

(* raised by unifytypes when two types cannot be unified *)
exception Dontunify
(* raised by unifyset when it is handed an empty list of types *)
exception Nondetectedemptyset

(* unifytypes (a, b): structural unification of two types.               *)
(*   A type variable on either side yields the other side (for two       *)
(*   variables, the second one).  Products, sets and or-sets unify       *)
(*   component-wise; each atomic type unifies only with itself.          *)
(*   Any other combination raises Dontunify.                             *)
fun unifytypes pair =
    case pair of
        (Vart _, other) => other
      | (other, Vart _) => other
      | (Prodt (l1, r1), Prodt (l2, r2)) =>
          Prodt (unifytypes (l1, l2), unifytypes (r1, r2))
      | (Sett e1, Sett e2) => Sett (unifytypes (e1, e2))
      | (Orsett e1, Orsett e2) => Orsett (unifytypes (e1, e2))
      | (Numt, Numt) => Numt
      | (Unitt, Unitt) => Unitt
      | (Boolt, Boolt) => Boolt
      | (Strt, Strt) => Strt
      | (Baset, Baset) => Baset
      | _ => raise Dontunify

(* unifyset types: unify all types of a non-empty list, nesting to the   *)
(* right: the head is unified with the unification of the tail.          *)
(* Raises Nondetectedemptyset on the empty list.                         *)
fun unifyset types =
    case types of
        [] => raise Nondetectedemptyset
      | [only] => only
      | first :: rest => unifytypes (first, unifyset rest)

(* gettypes now gets types of elements of the set *)

(* typeof x: infer the type of the complex object x.                     *)
(*   Atomic objects map to their atomic types and a pair to the product  *)
(*   of its component types.  For a set or or-set, the element types     *)
(*   are collected by gettypes and unified with unifyset.                *)
(*   gettypes always ends its result with a freshly numbered type        *)
(*   variable, so an empty collection gets a variable as element type    *)
(*   and unifyset never receives an empty list from here.                *)
(*   Raises Dontunify (from unifytypes) when element types disagree.     *)
fun typeof(x) =
  let val counter = ref 0
      (* next unused type variable; numbering starts at 1 *)
      fun fresh () = (counter := !counter + 1; Vart (!counter))
      fun infer obj =
          (case obj of
               Base _ => Baset
             | Unit => Unitt
             | Num _ => Numt
             | Bool _ => Boolt
             | Str _ => Strt
             | Prod (l, r) => Prodt (infer l, infer r)
             | Set elems => Sett (unifyset (gettypes elems))
             | Orset elems => Orsett (unifyset (gettypes elems)))
      (* element types in order, followed by one fresh type variable *)
      and gettypes [] = [fresh ()]
        | gettypes (e :: rest) = infer e :: gettypes rest
  in
    infer x
  end

end

end (* Functor TYPE *)

(**********    END OF TYPE WORK  *************)



