(* ========================================================================= *)
(* Preterms and pretypes; typechecking; translation to types and terms.      *)
(* ========================================================================= *)

(* Working representation of types during parsing and type inference.        *)
(* A Link wraps a mutable cell: the unifier instantiates a type by assigning *)
(* to the cell, and the traversal functions below look through Links by      *)
(* dereferencing them.                                                       *)
type pretype = Stv of int                      (* System type variable       *)
             | Utv of string                   (* User type variable         *)
             | Ptycon of string * pretype list (* Type constructor           *)
             | Link of pretype ref;;           (* Link for type unifier      *)

(* ------------------------------------------------------------------------- *)
(* Dummy pretype for the parser to stick in before a proper typing pass.     *)
(* ------------------------------------------------------------------------- *)

(* Placeholder pretype: a nullary type constructor with the empty name. The  *)
(* preterm builders later in this file attach it to each Varp they create.   *)
let Dpty = Ptycon("",[]);;

(* ------------------------------------------------------------------------- *)
(* Preterm syntax has antiquotation but the parser won't handle it.          *)
(* ------------------------------------------------------------------------- *)

(* Untyped/partially typed term syntax. In Absp the first component is the   *)
(* bound-variable preterm and the second is the body. Antiq carries no       *)
(* payload; the functions in this file either skip it or fail on it.         *)
type preterm = Varp of string * pretype       (* Variable           - v      *)
             | Constp of string * pretype     (* Constant           - c      *)
             | Combp of preterm * preterm     (* Combination        - f x    *)
             | Absp of preterm * preterm      (* Lambda-abstraction - \x. t  *)
             | Typing of preterm * pretype    (* Type constraint    - t : ty *)
             | Antiq of unit;;                (* Antiquotation      - ^mlexp *)

(* ------------------------------------------------------------------------- *)
(* Get the Utvs in a pretype.                                                *)
(* ------------------------------------------------------------------------- *)

(* Add to acc every user type variable (Utv node) occurring in pty, looking  *)
(* through Links. System type variables contribute nothing.                  *)
let rec pretyvars pty acc =
  match pty with
    Link(ref target) -> pretyvars target acc
  | Ptycon(_,args) -> itlist pretyvars args acc
  | Utv(_) -> insert pty acc
  | Stv(_) -> acc;;

(* ------------------------------------------------------------------------- *)
(* Scrub the link fields from a pretype.                                     *)
(* ------------------------------------------------------------------------- *)

(* Return a copy of pty in which every Link has been replaced by the pretype *)
(* it currently points to; the result contains no Link nodes.                *)
let rec clean_pretype pty =
  match pty with
    Link(ref target) -> clean_pretype target
  | Ptycon(name,args) -> Ptycon(name,map clean_pretype args)
  | Utv(_) -> pty
  | Stv(_) -> pty;;

(* ------------------------------------------------------------------------- *)
(* Useful to be able to whiz down through typings to a variable.             *)
(* ------------------------------------------------------------------------- *)

(* Strip any enclosing Typing constraints and return the Varp underneath.    *)
(* Fails (via fail) if something other than a variable is reached.           *)
let rec dive_to_var ptm =
  match ptm with
    Typing(inner,_) -> dive_to_var inner
  | Varp(_,_) -> ptm
  | _ -> fail();;

(* Name of a variable preterm; fails on any other kind of preterm.           *)
let var_name ptm =
  match ptm with
    Varp(name,_) -> name
  | _ -> failwith "var_name: Non-variable";;

(* ------------------------------------------------------------------------- *)
(* Flag to indicate that Stvs were translated to real type variables.        *)
(* ------------------------------------------------------------------------- *)

(* Set to true by type_of_pretype whenever it converts an Stv; cleared by    *)
(* term_of_preterm before each translation so it can issue a warning after.  *)
let stvs_translated = ref false;;

(* ------------------------------------------------------------------------- *)
(* Pretype <-> type conversion; pretype -> type flags any Stv translation.   *)
(* ------------------------------------------------------------------------- *)

(* Convert a pretype to a type, looking through Links. A system type         *)
(* variable Stv n becomes the type variable named "?" followed by n, and     *)
(* doing so sets stvs_translated so that callers can report it.              *)
let rec type_of_pretype ty =
  match ty with
    Link(ref target) -> type_of_pretype target
  | Ptycon(con,args) -> mk_type(con,map type_of_pretype args)
  | Utv(name) -> mk_vartype(name)
  | Stv(num) -> (stvs_translated := true;
                 mk_vartype("?"^(string_of_int num)));;

(* Convert a type to a pretype: constructor applications become Ptycon and   *)
(* anything dest_type rejects is treated as a type variable and becomes Utv. *)
let rec pretype_of_type ty =
  try
    let tycon,tyargs = dest_type ty in
    let pargs = map pretype_of_type tyargs in
    Ptycon(tycon,pargs)
  with Failure _ -> Utv(dest_vartype ty);;

(* ------------------------------------------------------------------------- *)
(* Pretype substitution for a pretype resulting from translation of type.    *)
(* ------------------------------------------------------------------------- *)

(* Apply the substitution th, a list of (replacement,Utv) pairs searched     *)
(* with rev_assoc, to a pretype built only from Ptycon and Utv nodes. Type   *)
(* variables not mentioned in th are left alone; Stv and Link nodes are      *)
(* rejected with a failure.                                                  *)
let rec pretype_subst th ty =
  match ty with
    Utv(_) -> (try rev_assoc ty th with Failure _ -> ty)
  | Ptycon(tycon,args) ->
      let subst_arg = pretype_subst th in
      Ptycon(tycon,map subst_arg args)
  | _ -> failwith "pretype_subst: Unexpected form of pretype";;

(* ------------------------------------------------------------------------- *)
(* Storage for overloading: type skeletons and various translations.         *)
(* ------------------------------------------------------------------------- *)

(* Table of overloadable symbols. Each entry maps the symbol name to a pair  *)
(* of its type skeleton and a list of (real constant name, pretype)          *)
(* instances; lookups elsewhere in this file take the first entry for a name.*)
let the_overloads =
  ref ([] :(string * (pretype * (string * pretype) list)) list);;

(* Reverse association: (real constant name, overloaded symbol) pairs,       *)
(* recorded by overload only when the two names differ.                      *)
let the_overloaded =
  ref ([] :(string * string) list);;

(* Register s as an overloadable symbol with type skeleton gty and, so far,  *)
(* no instances. Fails if s already has an entry in the_overloads.           *)
let make_overloadable s gty =
  let known = !the_overloads in
  if can (assoc s) known
  then failwith "make_overloadable: Already overloaded"
  else the_overloads := (s,(pretype_of_type gty,[]))::known;;

(* Add realname (at monomorphic type ty) as an instance of the overloadable  *)
(* symbol oname. The new instance goes to the front of the instance list.    *)
(*                                                                           *)
(* Fails if oname has no entry in the_overloads, if ty does not match the    *)
(* symbol's type skeleton, if ty contains type variables, or if an instance  *)
(* at the same type is already registered.                                   *)
(*                                                                           *)
(* The updated entry replaces the existing one: it is consed onto the table  *)
(* with the old entry removed, so the table holds one entry per symbol.      *)
let overload oname realname ty =
  let (_,(gty,olist)),others =
    remove (fun (s,_) -> s = oname) (!the_overloads) in
  if not (can (type_match (type_of_pretype gty) ty) [])
  then failwith "overload: Doesn't match type skeleton"
  else if not (tyvars ty = [])
  then failwith "overload: Can't overload polymorphic constants" else
  let pty = pretype_of_type ty in
  (if can (rev_assoc pty) olist
   then failwith "overload: Overload name and type already used" else
   the_overloads := (oname,(gty,(realname,pty)::olist))::others);
  (* Only record a reverse mapping when the names actually differ.           *)
  if oname = realname then ()
  else the_overloaded := (realname,oname)::(!the_overloaded);;

(* Remove every instance named realname from the overloadable symbol oname,  *)
(* and drop realname from the reverse table the_overloaded. Fails if oname   *)
(* has no entry in the_overloads.                                            *)
(*                                                                           *)
(* The ty argument is kept for interface compatibility but is not consulted: *)
(* instances are selected by name only.                                      *)
(*                                                                           *)
(* The updated entry replaces the existing one rather than being stacked on  *)
(* top of it, so the table holds one entry per symbol.                       *)
let unoverload oname realname ty =
  let (_,(gty,olist)),others =
    remove (fun (s,_) -> s = oname) (!the_overloads) in
  let olist' = filter (fun (s,_) -> not (s = realname)) olist in
  the_overloads := (oname,(gty,olist'))::others;
  the_overloaded :=
    filter (fun (s,_) -> not (s = realname)) (!the_overloaded);;

(* Make realname the first instance of oname: remove any existing instance   *)
(* (ignoring a failure to do so) and then register it afresh, which puts it  *)
(* at the head of the instance list.                                         *)
let prioritize_overload oname realname ty =
  let drop_existing () =
    try unoverload oname realname ty with Failure _ -> () in
  drop_existing();
  overload oname realname ty;;

(* Instance list, as (real name, pretype) pairs, of overloadable symbol s.   *)
(* Fails if s has no entry in the_overloads.                                 *)
let get_overloads s =
  let _,instances = assoc s (!the_overloads) in
  instances;;

(* Type skeleton (as a pretype) of overloadable symbol s.                    *)
(* Fails if s has no entry in the_overloads.                                 *)
let get_overload_skeleton s =
  let skeleton,_ = assoc s (!the_overloads) in
  skeleton;;

(* ------------------------------------------------------------------------- *)
(* Return new system type variable; reset the internal counter used.         *)
(* ------------------------------------------------------------------------- *)

(* new_type_var () returns a fresh system type variable wrapped in its own   *)
(* Link cell, numbered from a private counter; reset_type_num () sets that   *)
(* counter back to zero.                                                     *)
let new_type_var,reset_type_num =
  let counter = ref 0 in
  let fresh () =
    let current = !counter in
    counter := current + 1;
    Link(ref(Stv(current)))
  and reset () = counter := 0 in
  fresh,reset;;

(* ------------------------------------------------------------------------- *)
(* Gets a new substitution instance of constant's generic pretype.           *)
(* ------------------------------------------------------------------------- *)

(* Pretype for a fresh use of the constant cname: take its overload skeleton *)
(* if it has one, otherwise its declared constant type, and replace each     *)
(* user type variable in it by a new system type variable.                   *)
let get_generic_pretype cname =
  let skeleton =
    try get_overload_skeleton cname
    with Failure _ -> pretype_of_type(get_const_type cname) in
  let fresh_for tv = new_type_var(),tv in
  pretype_subst (map fresh_for (pretyvars skeleton [])) skeleton;;

(* ------------------------------------------------------------------------- *)
(* Handle constant hiding.                                                   *)
(* ------------------------------------------------------------------------- *)

(* Three operations over a private list of hidden constant names:            *)
(* hide_constant adds a name (no duplicates), unhide_constant removes it,    *)
(* and is_hidden tests membership.                                           *)
let hide_constant,unhide_constant,is_hidden =
  let hidden = ref ([]:string list) in
  (fun c -> hidden := union [c] (!hidden)),
  (fun c -> hidden := subtract (!hidden) [c]),
  (fun c -> mem c (!hidden));;

(* ------------------------------------------------------------------------- *)
(* Construct the internal representation of a numeral.                       *)
(* ------------------------------------------------------------------------- *)

(* Build the preterm for numeral n: NUMERAL applied to a chain of BIT0/BIT1  *)
(* constants, least significant bit outermost, ending in the constant _0.    *)
let pmk_numeral =
  let num_pty = Ptycon("num",[]) in
  let num_fn name = Constp(name,Ptycon("fun",[num_pty; num_pty])) in
  let zero = Constp("_0",num_pty) in
  let two = Int 2 in
  let rec bits n =
    if n =/ Int 0 then zero else
    let bit =
      if mod_num n two =/ Int 0 then num_fn "BIT0" else num_fn "BIT1" in
    Combp(bit,bits (quo_num n two)) in
  fun n -> Combp(num_fn "NUMERAL",bits n);;

(* ------------------------------------------------------------------------- *)
(* The library num_of_string wrongly accepts "" and "-"; this one rejects.   *)
(* ------------------------------------------------------------------------- *)

(* Parse a non-empty string of decimal digits as a number; any other string  *)
(* (including the empty one) causes a failure.                               *)
let num_of_string s =
  let is_digit c = "0" <= c & c <= "9" in
  if (s = "") or (not (forall is_digit (explode s)))
  then failwith "num_of_string"
  else num__num_of_string s;;

(* ------------------------------------------------------------------------- *)
(* Traverse preterm, recognizing constants and affixing system types to      *)
(* constants and variables. This function takes a list of variable bindings  *)
(* (bound variables override constant status) and an environment to allow    *)
(* linkage of the same name in different parts of the term.                  *)
(* ------------------------------------------------------------------------- *)

(* typify env ptm returns a pair: ptm with pretypes attached to its atoms,   *)
(* and the list of (name,pretype) bindings created for names that were not   *)
(* already in env. Constp and Antiq inputs are rejected with a failure.      *)
let rec typify env ptm =
  match ptm with
    (* A name is resolved by trying, in order: a binding in env, a numeral,  *)
    (* a non-hidden constant, and finally a new variable with a fresh type.  *)
    Varp(s,_) -> (try Varp(s,assoc s env),[] with Failure _ ->
                  try pmk_numeral(num_of_string s),[] with Failure _ ->
                  try if is_hidden s then fail()
                      else Constp(s,get_generic_pretype s),[]
                  with Failure _ ->
                      let nty = new_type_var() in Varp(s,nty),[s,nty])
    (* Bindings found in the operator are visible while typing the operand.  *)
  | Combp(f,x) -> let ftm,nenv = typify env f in
                  let xtm,eenv = typify (nenv@env) x in
                  Combp(ftm,xtm),eenv@nenv
    (* The bound variable is typed in an empty environment; its bindings     *)
    (* scope over the body only and are not returned to the caller.          *)
  | Absp(v,bod) -> let vv,venv = typify [] v in
                   let tm,eenv = typify (venv@env) bod in
                   Absp(vv,tm),eenv
    (* The constraint pretype itself is passed through unchanged.            *)
  | Typing(tm,ty) -> let ptm,nenv = typify env tm in
                     Typing(ptm,ty),nenv
  | _ -> failwith "typify: Unexpected preterm class";;

(* ------------------------------------------------------------------------- *)
(* Occurs check. The "trivial" flag is maintained "true" as long as we've    *)
(* just gone through links, to avoid occurs check failure for "t = t".       *)
(* ------------------------------------------------------------------------- *)

(* Does v occur in ty? Links are followed without changing the trivial flag, *)
(* while descending into a type constructor clears it. At a leaf the answer  *)
(* is true only when the leaf equals v and the flag has been cleared (or was *)
(* passed in as false), so that v against itself is not reported.            *)
let rec occurs_check trivial v ty =
  match ty with
    Link(ref target) -> occurs_check trivial v target
  | Ptycon(_,args) -> exists (occurs_check false v) args
  | _ -> if trivial then false else v = ty;;

(* ------------------------------------------------------------------------- *)
(* Side-effecting unification of types.                                      *)
(*                                                                           *)
(* The link ref allows us to perform a substitution everywhere simply by     *)
(* assigning to the ref, which is effectively a double dereference.          *)
(* ------------------------------------------------------------------------- *)

(* unify ty1 ty2 makes the two pretypes equal by assigning to Link cells     *)
(* that currently hold system type variables. Returns unit.                  *)
(*                                                                           *)
(* Fails if a variable would have to be bound to a type properly containing  *)
(* it, if two distinct type constructors meet, or if the two pretypes are    *)
(* otherwise incompatible (for example two different user type variables, or *)
(* a user type variable against a constructor). The last case is reported    *)
(* as a Failure like the others instead of escaping as a match failure.      *)
let rec unify ty1 ty2 =
  if ty1 = ty2 then () else
  match(ty1,ty2) with
    (Link(r),ty2) ->
        (match !r with
           (* Unbound variable: reject a proper occurrence, do nothing if    *)
           (* ty2 is this same variable behind links, otherwise bind it.     *)
           Stv(v) as sv -> if occurs_check true sv ty2 then
                             failwith "unify: Occurs check failure"
                           else (if occurs_check false sv ty2
                                 then () else r := ty2)
           (* Already instantiated: unify what the link points to.           *)
         | ty ->  unify ty ty2)
  | (ty1,(Link(r) as lr)) -> unify lr ty1
  | (Ptycon(s1,args1),Ptycon(s2,args2)) ->
        if s1 = s2 then do_list2 unify args1 args2
        else failwith "unify: Attempt to unify distinct constructors"
  | _ -> failwith "unify: Attempt to unify incompatible pretypes";;

(* ------------------------------------------------------------------------- *)
(* Preterm typechecker.                                                      *)
(*                                                                           *)
(* The "chase" function finds the codomain type of a function type (to type  *)
(* a comb we only need to do this to the type of the rator). It also follows *)
(* links as far as it needs to do to get to a Ptycon constructor.            *)
(*                                                                           *)
(* TC is a side-effecting typechecker which returns the pretype of the       *)
(* preterm it is given.                                                      *)
(* ------------------------------------------------------------------------- *)

(* typecheck ptm runs type inference over ptm purely for its effect on the   *)
(* Link cells inside the pretypes, then returns ptm itself.                  *)
let typecheck =
  (* Result type of a function pretype, looking through Links; fails on      *)
  (* anything that is not (a link to) a "fun" constructor with two arguments.*)
  let rec chase typ =
    match typ with
      Ptycon("fun",[_;ty]) -> ty
    | Link(ref ty) -> chase ty
    | _ -> failwith "chase: Function type expected" in
  (* Pretype of a preterm, unifying as it goes.                              *)
  let rec TC ptm =
    match ptm with
      Varp(_,ty) -> ty
    | Constp(_,ty) -> ty
      (* Force the operator to be a function from the operand's type to a    *)
      (* fresh type variable, then read the result type back off it.         *)
    | Combp(f,x) -> let fty = TC f in
                   (unify fty (Ptycon("fun",[TC x; new_type_var()])); chase fty)
    | Absp(bv,bod) -> Ptycon("fun",[TC bv; TC bod])
    | Typing(tm,ty) -> (unify (TC tm) ty; ty)
    | Antiq _ -> failwith "TC: Can't handle antiquotes" in
  fun ptm -> (TC ptm; ptm);;

(* ------------------------------------------------------------------------- *)
(* Resolve overloaded identifiers according to assigned type, if possible.   *)
(* ------------------------------------------------------------------------- *)

(* Set by overresolve when it had to fall back to the first listed instance  *)
(* of an overloaded symbol; retypecheck clears it before each run.           *)
let overloads_defaulted = ref false;;

(* Build a Constp if s is the name of a constant (get_const_type succeeds),  *)
(* and a Varp otherwise, with the given pretype in both cases.               *)
let pmk_cv(s,pty) =
  let is_const = can get_const_type s in
  if is_const then Constp(s,pty) else Varp(s,pty);;

(* Replace each overloaded constant in ptm by one of its instances.          *)
(*                                                                           *)
(* If the constant's inferred pretype (with links scrubbed) equals the       *)
(* pretype of a registered instance, that instance is used. Otherwise the    *)
(* first registered instance is used, wrapped in a Typing constraint with    *)
(* that instance's pretype, and overloads_defaulted is set so the caller     *)
(* knows another typechecking pass is needed. Constants that are not         *)
(* overloadable are returned unchanged, as are variables. Bound variables    *)
(* of abstractions are not traversed. Fails on antiquotations.               *)
let rec overresolve ptm =
  match ptm with
    Varp(_,_) -> ptm
  | Constp(s,pty) -> (try let olist = get_overloads s in
                          try let s' = rev_assoc (clean_pretype pty) olist in
                              pmk_cv(s',pty)
                          with Failure _ ->
                              overloads_defaulted := true;
                              let (s',pty') = hd olist in
                              Typing(pmk_cv(s',pty),pty')
                      with Failure _ -> ptm)
  | Combp(f,x) -> Combp(overresolve f,overresolve x)
  | Absp(bv,bod) -> Absp(bv,overresolve bod)
  | Typing(tm,ty) -> Typing(overresolve tm,ty)
  | Antiq _ -> failwith "overresolve: Can't handle antiquotes";;

(* ------------------------------------------------------------------------- *)
(* Rerun typechecking if necessary to resolve overloaded identifiers.        *)
(* ------------------------------------------------------------------------- *)

(* Full typing pipeline: typify, typecheck, resolve overloads, and run the   *)
(* typechecker a second time if any overload had to be defaulted.            *)
let retypecheck env ptm =
  overloads_defaulted := false;
  let typed,_ = typify env ptm in
  let resolved = overresolve (typecheck typed) in
  if !overloads_defaulted then typecheck resolved else resolved;;

(* ------------------------------------------------------------------------- *)
(* Eliminates dollared identifiers (expects no constants!)                   *)
(* ------------------------------------------------------------------------- *)

(* Strip one leading dollar sign from every variable name in ptm. A name     *)
(* consisting of just the dollar sign is an error, as is any Constp or       *)
(* Antiq node.                                                               *)
let rec undollar ptm =
  match ptm with
    Combp(t1,t2) -> Combp(undollar t1,undollar t2)
  | Absp(bv,t) -> Absp(undollar bv,undollar t)
  | Typing(t,ty) -> Typing(undollar t,ty)
  | Varp(n,ty) ->
      let chars = explode n in
      if chars = ["$"] then failwith "Empty dollared identifier" else
      (match chars with
         "$"::rest -> Varp(implode rest,ty)
       | _ -> ptm)
  | _ -> failwith "Unexpected constant in a pre-preterm";;

(* ------------------------------------------------------------------------- *)
(* Maps preterms to terms.                                                   *)
(* ------------------------------------------------------------------------- *)

(* Translate a typechecked preterm into a term. Typing constraints are       *)
(* dropped and antiquotations are rejected. If any system type variable had  *)
(* to be turned into a real type variable along the way, a warning is        *)
(* issued through warn.                                                      *)
let term_of_preterm =
  let rec translate ptm =
    match ptm with
      Typing(inner,_) -> translate inner
    | Varp(s,pty) -> mk_var(s,type_of_pretype pty)
    | Constp(s,pty) -> mk_mconst(s,type_of_pretype pty)
    | Combp(l,r) -> mk_comb(translate l,translate r)
    | Absp(v,bod) -> mk_gabs(translate v,translate bod)
    | Antiq _ -> failwith "term_of_preterm: Can't handle antiquote" in
  fun ptm ->
    (stvs_translated := false;
     let tm = translate ptm in
     (warn (!stvs_translated) "inventing type variables"; tm));;

(* ------------------------------------------------------------------------- *)
(* Maps the other way (used by the printer, for example).                    *)
(* ------------------------------------------------------------------------- *)

(* Translate a term into a preterm by trying each destructor in turn:        *)
(* variable, constant, abstraction, and finally combination. If none         *)
(* applies, the failure from dest_comb propagates.                           *)
let rec preterm_of_term tm =
  try
    let name,ty = dest_var tm in
    Varp(name,pretype_of_type ty)
  with Failure _ ->
    (try
       let name,ty = dest_const tm in
       Constp(name,pretype_of_type ty)
     with Failure _ ->
       (try
          let bv,body = dest_abs tm in
          Absp(preterm_of_term bv,preterm_of_term body)
        with Failure _ ->
          let rator,rand = dest_comb tm in
          Combp(preterm_of_term rator,preterm_of_term rand)));;

(* ------------------------------------------------------------------------- *)
(* A few extra operations needed on preterms as well as terms.               *)
(* ------------------------------------------------------------------------- *)

(* pfrees ptm acc returns the free variable preterms of ptm followed by acc. *)
(* Occurrences are not deduplicated. The variable with empty name and dummy  *)
(* pretype is skipped, as are constants and antiquotations.                  *)
(*                                                                           *)
(* For an abstraction, the variables of the binder are removed only from the *)
(* free variables of its own body. The incoming accumulator is appended      *)
(* untouched, so a variable that is free elsewhere in the term is not lost   *)
(* because an unrelated abstraction happens to bind the same name.           *)
let rec pfrees ptm acc =
  match ptm with
    Varp(v) -> if v = ("",Dpty) then acc else ptm::acc
  | Constp(_) -> acc
  | Combp(p1,p2) -> pfrees p1 (pfrees p2 acc)
  | Absp(p1,p2) -> (subtract (pfrees p2 []) (pfrees p1 [])) @ acc
  | Typing(p,_) -> pfrees p acc
  | Antiq(_) -> acc;;

(* Each call returns a new variable preterm whose name is GEN%PVAR% followed *)
(* by the next value of a private counter, with the dummy pretype.           *)
let pgenvar =
  let next = ref 0 in
  fun () ->
    let index = !next in
    next := index + 1;
    Varp("GEN%PVAR%"^(string_of_int index),Dpty);;

(* Flatten a right-nested application of the comma variable into the list of *)
(* its components; a preterm that is not such an application yields a        *)
(* one-element list. The pretype attached to the comma is not inspected.     *)
let rec split_ppair ptm =
  match ptm with
    Combp(Combp(Varp(",",_),first),rest) -> first::(split_ppair rest)
  | _ -> [ptm];;

(* Preterm for the equation of ptm1 and ptm2.                                *)
let pmk_eq(ptm1,ptm2) =
  let eq_op = Varp("=",Dpty) in
  Combp(Combp(eq_op,ptm1),ptm2);;

(* Preterm for the conjunction of ptm1 and ptm2.                             *)
let pmk_conj(ptm1,ptm2) =
  let conj_op = Varp("/\\",Dpty) in
  Combp(Combp(conj_op,ptm1),ptm2);;

(* Preterm for the existential quantification of ptm over v.                 *)
let pmk_exists(v,ptm) =
  let quantifier = Varp("?",Dpty) in
  Combp(quantifier,Absp(v,ptm));;

(* Preterm for a let-expression. Each binding must be an application of the  *)
(* variable named = to a left and a right side. The result is LET applied to *)
(* an abstraction over all left sides of LET_END applied to body, and then   *)
(* applied in turn to each right side.                                       *)
let pmk_let =
  let pdest_eq (Combp(Combp(Varp("=",_),lhs),rhs)) = lhs,rhs in
  let abstract v t = Absp(v,t)
  and apply arg fn = Combp(fn,arg) in
  fun (letbindings,body) ->
    let vars,tms = unzip (map pdest_eq letbindings) in
    let core = Combp(Varp("LET_END",Dpty),body) in
    let head = Combp(Varp("LET",Dpty),itlist abstract vars core) in
    rev_itlist apply tms head;;

(* Preterm for a set enumeration: INSERT applied right-nested over the       *)
(* elements, ending in EMPTY.                                                *)
let pmk_set_enum ptms =
  let add_elem x rest = Combp(Combp(Varp("INSERT",Dpty),x),rest) in
  itlist add_elem ptms (Varp("EMPTY",Dpty));;

(* Preterm for a set abstraction with element expression fabs and condition  *)
(* babs: GSPEC applied to an abstraction over a fresh variable of the        *)
(* conjunction of babs with an equation between that variable and fabs,      *)
(* existentially quantified over the free variables common to both.          *)
let pmk_setabs (fabs,babs) =
  let shared = intersect (pfrees fabs []) (pfrees babs []) in
  let gv = pgenvar() in
  let quantify ev body = pmk_exists(ev,body) in
  let matrix = pmk_conj(babs,pmk_eq(gv,fabs)) in
  Combp(Varp("GSPEC",Dpty),Absp(gv,itlist quantify shared matrix));;

(* Preterm for binder n over the single variable v: a bare abstraction when  *)
(* n is the backslash, otherwise the variable n applied to the abstraction.  *)
let pmk_vbinder(n,v,bod) =
  let lam = Absp(v,bod) in
  if n="\\" then lam else Combp(Varp(n,Dpty),lam);;

(* Preterm for binder n over the variables vs, nested with the first         *)
(* variable outermost.                                                       *)
let pmk_binder(n,vs,bod) =
  let bind_one v inner = pmk_vbinder(n,v,inner) in
  itlist bind_one vs bod;;

(* Preterm for a list literal: CONS applied right-nested over the elements,  *)
(* ending in NIL.                                                            *)
let pmk_list els =
  let cons_elem x rest = Combp(Combp(Varp("CONS",Dpty),x),rest) in
  itlist cons_elem els (Varp("NIL",Dpty));;
