(*  Title: 	type
    Author: 	Lawrence C Paulson, Cambridge University Computer Laboratory
    Copyright   Cambridge University 1988

Type inference (for input of terms)
*)

signature TYPE = 
sig
  (*Symbol tables; instantiated at typ for the constant and identifier tables*)
  structure Symtab: SYMTAB

  (*Fresh type variables: reset the counter / generate a new variable*)
  val init: unit -> unit
  val gen_tyvar: unit -> typ

  (*Operations on type assignments, represented as (name, type) lists*)
  val occ: string * typ * (string * typ) list -> bool
  val unify: typ * typ * (string * typ) list -> (string * typ) list
  val inst: (string * typ) list -> typ -> typ
  val inst_term: (string * typ) list -> term -> term

  (*Type inference proper*)
  val infer1: typ list * term * (string * typ) list -> typ * (string * typ) list
  val infer_term: typ Symtab.table * typ Symtab.table * typ * term -> term
end;

functor TypeFun(Symtab: SYMTAB) : TYPE = 
struct
structure Symtab = Symtab;


(*Counter from which fresh type variable names are generated; see init.*)
val tyvar_count = ref 0;

(*Reset the fresh-type-variable counter to zero.*)
fun init() = (tyvar_count := 0);

(*Return a fresh type variable named "." followed by the incremented counter.
  NOTE(review): the leading "." presumably keeps these names apart from the
  identifier-named type variables built by id_type -- confirm that
  identifiers can never start with ".".*)
fun gen_tyvar() = 
    (tyvar_count := !tyvar_count+1;
     Poly ("." ^ string_of_int (!tyvar_count)));


(*Occurs check: does type variable a occur in the type, following any
  assignments recorded in tye?*)
fun occ (a, Ground _, tye) = false
  | occ (a, Poly b, tye) = 
	if a=b then true
	else (case assoc(tye,b) of
	          None   => false
	        | Some U => occ(a,U,tye))
  | occ (a, T-->U, tye) = occ(a,T,tye)  orelse  occ(a,U,tye);


(*Raised if types are not unifiable*)
exception TUNIFY;

(*Chase variable assignments in tye.  
  If devar (T,tye) returns a type var then it must be unassigned.*) 
fun devar (Poly a, tye) =
      (case  assoc(tye,a)  of
          Some U =>  devar (U,tye)
        | None   =>  Poly a)
  | devar (T,tye) = T;


(*Unification of types.
  unify (T, U, tye) returns tye extended with the assignments needed to make
  T and U equal, or raises TUNIFY.
  unify_var expects both arguments already chased by devar.  So a Poly in
  the first position is unassigned.  It is either identical to U (nothing
  to do), occurs in U (fail), or gets bound to U.*)
fun unify_var (Poly a, U, tye) =  if U = Poly a  then  tye
         else if occ(a,U,tye) then  raise TUNIFY  else  (a,U)::tye
  | unify_var (T,U,tye) = unify(T,U,tye)
and unify (Poly a, T, tye) = unify_var (devar(Poly a, tye), devar(T,tye), tye)
    (*Swap so that the variable is passed first to unify_var*)
  | unify (T, Poly a, tye) = unify_var (devar(Poly a, tye), devar(T,tye), tye)
  | unify (Ground a, Ground b, tye) =
        if a=b then tye  else  raise TUNIFY
  | unify (T1-->T2, U1-->U2, tye) =  
        unify (T2, U2, unify (T1, U1, tye))
  | unify _ =  raise TUNIFY;



(*Instantiation of type variables in types: replace every assigned
  variable by its (recursively instantiated) value; unassigned ones stay.*)
fun inst tye (Ground a) = Ground a
  | inst tye (Poly a) = 
      (case  assoc(tye,a)  of
	  Some U =>  inst tye U
	| None   =>  Poly a)
  | inst tye (T-->U) = inst tye T --> inst tye U;

(*Instantiation of type variables in terms
    Delete explicit constraints -- occurrences of ".constrain".
    The ".constrain" clause must precede the general application clause,
    otherwise it would never be reached.*)
fun inst_term tye (Const(a,T)) = Const(a, inst tye T)
  | inst_term tye (Free(a,T)) = Free(a, inst tye T)
  | inst_term tye (Var(v,T)) = Var(v, inst tye T)
  | inst_term tye (Bound i)  = Bound i
  | inst_term tye (Abs(a,T,t)) = Abs(a, inst tye T, inst_term tye t)
  | inst_term tye (Const(".constrain",_) $ t) = inst_term tye t
  | inst_term tye (f$t) = inst_term tye f $ inst_term tye t;


(*Type inference for polymorphic term.
  Ts: types of the enclosing abstractions' bound variables, most recently
      entered first (Abs pushes its type onto the front).
  Returns the term's type together with the extended assignment list.
  Raises TYPE on a loose bound variable, a non-function operator, or an
  argument type mismatch.*)
fun infer1 (Ts, Const (_,T), tye) = (T,tye)
  | infer1 (Ts, Free  (_,T), tye) = (T,tye)
  | infer1 (Ts, Bound i, tye) = ((nth_elem(i,Ts) , tye)
      handle LIST _=> raise TYPE ("loose bound variable", [], [Bound i]))
  | infer1 (Ts, Var (_,T), tye) = (T,tye)
  | infer1 (Ts, Abs (_,T,body), tye) = 
	let val (U,tye') = infer1(T::Ts, body, tye)
	in  (T-->U, tye')  end
  | infer1 (Ts, f$u, tye) = 
	let val (U,tyeU) = infer1 (Ts, u, tye);
	    val (T,tyeT) = infer1 (Ts, f, tyeU)
	in (case T of
	      T1-->T2 => (T2, unify(T1, U, tyeT))
	    | Poly _ => 
		(*Unknown operator type: make it U --> fresh result type*)
		let val T2 = gen_tyvar()
		in  (T2, unify(T, U-->T2, tyeT))  end
	    | _ => raise TYPE ("Rator must have function type",
				  [inst tyeT T, inst tyeT U], [f$u]))
	  handle TUNIFY => raise TYPE
	     ("type mismatch in application", 
	      [inst tyeT T, inst tyeT U], [f$u])
	end;


(*Attach a type to a constant.  == is the only constant requiring a new tyvar.
  Uses this structure's own gen_tyvar so that its names come from the same
  counter as infer1's.  The counter is reset by init, and names cannot clash
  with other fresh variables.
  (equals is defined elsewhere; presumably it builds the "==" constant at
  the given type -- NOTE(review): confirm.)*)
fun type_const (a,T) =
    if a = "=="  then  equals(gen_tyvar())  
    else   Const(a,T);


(*Find type of ident.  If not in table then use ident's name for tyvar
  to get consistent typing.*)
fun id_type (itab, a) =
    case Symtab.lookup(itab, a) of
	Some T => T
      | None => Poly a;


(*Attach types to a term.  Input is a "parse tree" containing dummy types.
  Constants take their declared type from const_tab (TYPE if undeclared).
  A Free whose name is a declared constant becomes that constant.
  Other Frees, Vars and abstraction variables are typed by id_type.*)
fun add_types (const_tab, ident_tab) =
  let fun add (Const(a,_)) =
	    (case Symtab.lookup(const_tab, a) of
		Some T => type_const(a,T)
	      | None => raise TYPE ("No such constant: "^a, [], []))
	| add (Bound i) = Bound i
	| add (Free(a,_)) =
	    (case Symtab.lookup(const_tab, a) of
		Some T => type_const(a,T)
	      | None => Free(a, id_type(ident_tab, a)))
	| add (Var((a,i),_)) = Var((a,i), id_type(ident_tab, a))
	| add (Abs(a,_,body)) = Abs(a, id_type(ident_tab, a), add body)
	| add (f$t) = add f $ add t
  in  add  end;


(*Infer types for term t using tables.  Check that t has type T.
  Returns t with all type variables instantiated and ".constrain" removed.
  On failure the error reports both types instantiated under the current
  assignments, as infer1 does, so the user sees the actual inferred type.*)
fun infer_term (const_tab, ident_tab, T, t) = 
    let val u = add_types (const_tab, ident_tab) t;
	val (U,tye) = infer1 ([], u, [])
    in  inst_term (unify(T, U, tye)) u  
      handle TUNIFY => raise TYPE
	("Term does not have expected type",
	 [inst tye T, inst tye U], [inst_term tye u])
    end;

end;


