(*  Title: 	Types and Sorts
    Author:	Tobias Nipkow & Lawrence C Paulson

Type inference code used to be part of sign.
*)


(* Interface of the type and sort module: order-sorted type signatures,
   well-formedness checks, type matching and unification, type inference,
   and conversion between TVars and TFrees. *)
signature TYPE =
sig
  structure Symtab:SYMTAB
  type type_sig;
  (* The default sort: the last declared sort *)
  val defaultS: type_sig -> sort
  (* extend(tsig, new sorts paired with their supersort names,
            additional arities for already declared type constructors,
            arities for new type constructors).
     An arity entry is (constructor names, (argument sort names, result sort name)). *)
  val extend: type_sig * (string * string list)list *
	      (string list * (string list * string))list *
	      (string list * (string list * string))list -> type_sig
  (* Turn all TVars except the listed ones into TFrees with fresh names *)
  val freeze: term * indexname list -> term
  (* infer_types(tsig, constant type table, types of identifiers, sorts of
     type variables, typ printer, expected type, term) returns the typed term
     together with the instantiation of the TVars of the expected type *)
  val infer_types: type_sig * typ Symtab.table * (string -> typ option) *
		   (string -> sort option) * (typ -> string) * typ * term
		   -> term * (indexname*typ)list
  (* Instantiate TVars; each replacement is checked against the TVar's sort *)
  val inst_typ_tvars: type_sig * (indexname * typ)list -> typ -> typ
  val inst_term_tvars: type_sig * (indexname * typ)list -> term -> term
  (* Has type constructor c an arity whose result sort is below "logic"? *)
  val logical_type: type_sig -> string -> bool
  (* Merge two signatures; fails unless one sort structure extends the other *)
  val merge: type_sig * type_sig -> type_sig
  (* Initial signature containing only the sort "any" *)
  val tsig0: type_sig
  val sort_of_string: type_sig -> string -> sort
  val string_of_sort: type_sig -> sort -> string
  (* Accumulate messages explaining why a type is not well-formed *)
  val type_errors: type_sig * (typ->string) -> typ * string list -> string list
  (* typ_instance(tsig,T,U): is T an instance of U? *)
  val typ_instance: type_sig * typ * typ -> bool
  (* Extend a substitution so that it maps the first type onto the second *)
  val typ_match: type_sig -> (indexname*typ)list * (typ*typ) ->
		 (indexname*typ)list
  (* Extend a substitution to a sort-respecting unifier; raises TUNIFY *)
  val unify: type_sig -> (typ*typ) * (indexname*typ)list -> (indexname*typ)list
  (* Turn TFrees into TVars with index 0 *)
  val varifyT: typ -> typ
  (* Turn TFrees except the listed ones into TVars *)
  val varify: term * string list -> term
  exception TUNIFY
  exception TYPE_MATCH;
end;

functor TypeFun(structure Symtab:SYMTAB and Syntax:SYNTAX) : TYPE =
struct
(* Re-export the symbol table structure passed as functor argument *)
structure Symtab = Symtab

(* Miscellany *)

(* Render a list of sort indices as their names, separated by ", " *)
fun commas sorts Ss =
    let fun name S = nth_elem(S,sorts)
    in space_implode ", " (map name Ss) end;

(* Cartesian product of a list of lists:
   x_prod [l1,..,ln] = l1 x ... x ln, each element a list [x1,..,xn] with xi in li;
   x_prod [] = [[]] *)
fun x_prod [] = [[]]
  | x_prod (l::ls) =
      let val tails = x_prod ls
      in flat (map (fn x => map (fn tl => x::tl) tails) l) end;

(* Array Manipulation *)

(* for(ub,p) calls p 0, p 1, ..., p ub in increasing order;
   nothing happens when ub < 0 *)
fun for(ub:int, p: int -> unit) : unit =
    let fun loop i = if i <= ub then (p i; loop(i+1)) else ()
    in loop 0 end;

(* A matrix is an array of row arrays, indexed as Array.sub(Array.sub(m,i),j) *)
type 'a matrix = 'a Array.array Array.array;

(* A fresh (n+1) x (n+1) matrix with every entry v; every row is a separate array *)
fun init_matrix(n,v) =
    let val dim = n+1
        fun new_row _ = Array.array(dim,v)
    in Array.arrayoflist(map new_row (replicate dim ())) end;

(* Copy square matrix a into a fresh matrix with k extra rows and columns;
   the new entries are all v *)
fun extend_matrix(a,k,v) =
let val last = Array.length a - 1
    val b = init_matrix(last+k, v)
    fun copy_row i =
        let val src = Array.sub(a,i)
            val dst = Array.sub(b,i)
        in for(last, fn j => Array.update(dst, j, Array.sub(src,j))) end
in for(last, copy_row); b end;

infix subarray;
(* a1 subarray a2: a1 is no longer than a2 and agrees with a2 on all of a1's indices *)
fun a1 subarray a2 =
let val len1 = Array.length a1
    fun agree i =
        i >= len1 orelse
        (Array.sub(a1,i) = Array.sub(a2,i) andalso agree(i+1))
in len1 <= Array.length a2 andalso agree 0 end;

infix submatrix;
(* m1 submatrix m2: m1 has no more rows than m2 and each row of m1 is a
   subarray of the corresponding row of m2 *)
fun m1 submatrix m2 =
let val rows1 = Array.length m1
    fun rows_agree i =
        i >= rows1 orelse
        ((Array.sub(m1,i) subarray Array.sub(m2,i)) andalso rows_agree(i+1))
in rows1 <= Array.length m2 andalso rows_agree 0 end;


(* ORDER-SORTED TYPE SIGNATURE *)

(* Argument sorts of a type constructor *)
type domain = sort list;
(* An arity: argument sorts and result sort *)
type arity = domain * sort;

type type_sig =
   {sorts: string list,			(* sort names; a sort is its index here *)
    subsort: bool matrix,		(* subsort[S1][S2] iff S1 <= S2 *)
    inf: sort option matrix,		(* infimum of two sorts, if any (see downward_complete) *)
    coreg: (string * domain)list Array.array, (* per sort S: (f,w) with w the unique
					   maximal domain of an arity of f whose result is <= S *)
    types: (string * arity list) list};	(* declared arities of each type constructor *)


(* The initial type signature: only the sort "any" (index 0), which is its
   own subsort and infimum, and no type constructors *)
val tsig0:type_sig =
	{types = [],
	 coreg = Array.arrayoflist[[]],
	 inf = init_matrix(0,Some(0)),
	 subsort = init_matrix(0,true),
	 sorts = ["any"]};

(* Error helpers for undeclared sorts and type constructors and bad sort indices *)
fun undcl_sort s = error(implode["Sort ", s, " has not been declared"]);

fun undcl_type c = implode["Undeclared type: ", c];
fun undcl_type_err c = error(undcl_type c);

fun out_of_range(S,max) =
    implode["Sort ", string_of_int S, " out of range 0 .. ", string_of_int max];

(* Index of the sort named s; reports an error if s is not declared *)
fun sort_of_string (tsig:type_sig) s =
	find(s, #sorts tsig) handle LIST _ => undcl_sort s;

(* Name of sort S; reports an error if S is not a valid sort index *)
fun string_of_sort (tsig:type_sig) S =
	let val names = #sorts tsig
	in nth_elem(S,names)
	   handle LIST _ => error(out_of_range(S, length names - 1))
	end;

(* The last declared sort is the default sort *)
fun defaultS (tsig:type_sig) = length(#sorts tsig) - 1;

(* Does type constructor c have an arity whose result sort is below "logic"?
   Reports an error if c is undeclared. *)
fun logical_type(tsig as {subsort,types,...}:type_sig) c =
let val logS = sort_of_string tsig "logic"
    fun below_logic (_, S) = Array.sub(Array.sub(subsort,S),logS)
in case assoc(types,c) of
     None => undcl_type_err(c)
   | Some(ars) => exists below_logic ars
end;

(* Componentwise subsort ordering on sort tuples: w1 <= w2 *)
fun lew a (w1,w2) =
    let fun le (S1,S2) = Array.sub(Array.sub(a,S1),S2)
    in forall le (w1~~w2) end;

(* All common lower bounds of S1 and S2 under subsort matrix a, in ascending order *)
fun lbds a (S1,S2) =
    let fun le(S,T) = Array.sub(Array.sub(a,S),T)
        fun collect(S,acc) =
            if S < 0 then acc
            else collect(S-1, if le(S,S1) andalso le(S,S2) then S::acc else acc)
    in collect(Array.length a - 1, []) end;

(* All sort tuples w with w <= w1 and w <= w2 componentwise *)
fun glbsw a w1 w2 =
    let val bounds = map (fn pair => lbds a pair) (w1 ~~ w2)
    in x_prod bounds end;

(* The minimal elements of sort list xs under subsort matrix a *)
fun minima a xs =
    let fun le(x,y) = Array.sub(Array.sub(a,x),y)
        fun insert(x, mins) =
            if exists (fn m => le(m,x)) mins then mins
            else x :: filter_out (fn m => le(x,m)) mins
    in foldr insert (xs, []) end;

(* The maximal elements of sort list xs under subsort matrix a *)
fun maxima a xs =
    let fun le(x,y) = Array.sub(Array.sub(a,x),y)
        fun insert(x, maxs) =
            if exists (fn m => le(x,m)) maxs then maxs
            else x :: filter_out (fn m => le(m,x)) maxs
    in foldr insert (xs, []) end;

(* The maximal sort tuples of ws under the componentwise ordering lew *)
fun maximaw a ws =
    let fun insert(w, maxs) =
            if exists (fn m => lew a (w,m)) maxs then maxs
            else w :: filter_out (fn m => lew a (m,w)) maxs
    in foldr insert (ws, []) end;

(* Result sorts of those arities whose domain lies above w *)
fun cod_above(a,w,ars) =
    let fun pick [] = []
          | pick ((w',S)::rest) = if lew a (w,w') then S :: pick rest else pick rest
    in pick ars end;

(* Domains of those arities whose result sort lies below S *)
fun dom_below(a,ars,S) =
    let fun pick [] = []
          | pick ((w,S')::rest) =
              if Array.sub(Array.sub(a,S'),S) then w :: pick rest else pick rest
    in pick ars end;


(* DOWNWARD COMPLETENESS *)
(* any two sorts have either no common lower bound or a greatest one *)

(* Report that S1 and S2 have several maximal lower bounds Ss *)
fun dc_err(sorts,S1,S2,Ss) =
    let fun name S = nth_elem(S,sorts)
    in error(implode["Your sort structure is not downward complete:\n",
                     "Sorts ", name S1, " and ", name S2,
                     " have more than one ``infimum'': {", commas sorts Ss, "}"])
    end;

(* Build the infimum matrix for subsort matrix a.  Entry [S1][S2] is
   Some(glb) when the common lower bounds of S1 and S2 have a unique maximum,
   and None when there are no common lower bounds.  Reports an error (dc_err)
   when there are several maximal ones. *)
fun downward_complete(sorts,a: bool matrix) =
let val maxS = Array.length(a)-1
    val inf = init_matrix(maxS,None:int option)
    fun set(S1,S2,S) = Array.update(Array.sub(inf,S1),S2,Some(S))
    fun glb(S1,S2) =
        case maxima a (lbds a (S1,S2)) of
            [] => ()
          | [S] => (set(S1,S2,S); set(S2,S1,S))
          | Ss => dc_err(sorts,S1,S2,Ss)
    (* visit every pair (S1,S2) with S1 < S2 *)
    fun pairs_from S1 =
        let fun loop S2 = if S2 > maxS then () else (glb(S1,S2); loop(S2+1))
        in loop(S1+1) end
in for(maxS, fn S => set(S,S,S));
   for(maxS, pairs_from);
   inf
end;

(* REGULARITY - every type must have a least sort *)
(* based on Schmidt-Schauss: non-regular if there are f:(w1)S1, f:(w2)S2 and
   a tuple w such that
   1. neither S1 <= S2 nor S2 <= S1,
   2. w <= w1 and w <= w2,
   3. there is no f:(w')s' such that w <= w' and s' <= S1 and s' <= S2
*)
(* Collect regularity violations.  For each type constructor f and each pair
   of its arities f:(w1)S1, f:(w2)S2 whose result sorts are incomparable under
   subsort matrix a, every tuple w in glbsw a w1 w2 is checked for an arity
   f:(w')s' with w <= w' and s' below both S1 and S2.  A failing w gives the
   violation (f,w,S1,S2).  At most one violation is reported per type
   constructor, because the search for f stops at the first one. *)
fun regularity_errs(a,types) =
let fun reg(f,arities) =
	let fun reg1((w1,S1)::as1) =
		let fun reg2((w2,S2)::as2) =
			let fun test(w::ws) =
				let fun less(w',s') = lew a (w,w') andalso
					Array.sub(Array.sub(a,s'),S1) andalso
					Array.sub(Array.sub(a,s'),S2)
				in if exists less arities then test(ws)
				   else [(f,w,S1,S2)] end
			      | test([]) = []
			in if Array.sub(Array.sub(a,S1),S2) orelse
			      Array.sub(Array.sub(a,S2),S1)
			   then reg2(as2) else case test(glbsw a w1 w2) of
				[] => reg2(as2) | err => err end
		      | reg2([]) = []
		in case reg2(as1) of [] => reg1(as1) | err => err end
	      | reg1([]) = []
	in reg1(arities) end
in flat(map reg types) end;

(* Message for one regularity violation (f,w,S1,S2) *)
fun reg_err sorts (f,w,S1,S2) =
    let fun name S = nth_elem(S,sorts)
    in implode["The type ", f, "(", commas sorts w, ") has the sorts ",
               name S1, " and ", name S2, " but no least sort."]
    end;

(* Check regularity of the arities in types against subsort matrix a.
   Reports an error listing every violation found by regularity_errs, one per
   line.  The messages are joined with cat_lines rather than implode so that
   several violations do not run together. *)
fun regularity(sorts,a,types) = case regularity_errs(a,types) of
	[] => () |
	errs => error("Your sort structure is not regular because:\n" ^
		      cat_lines (map (reg_err sorts) errs));


(* COREGULARITY *)

(* Report an error listing each (f,S) for which the maximal domain of f
   below S is not unique; do nothing if errs is empty *)
fun coreg_errs(errs,sorts) =
    let fun msg (f,S) = "There is more than one maximal w such that " ^ f ^
                        ": w -> s and s <= " ^ nth_elem(S,sorts) ^ "\n"
    in case errs of
         [] => ()
       | _ => error(implode(map msg errs))
    end;

(* Compute the coregularity table: for each sort S, the list of (f,w) where w
   is the unique maximal domain among f's arities with result sort <= S.
   Reports an error via coreg_errs if some maximal domain is not unique. *)
fun coregular(sorts,a,decls) =
let datatype Outcome = NoDom | Unique of string * domain | Ambiguous of string * sort
    fun uniques rs = flat(map (fn Unique(f,w) => [(f,w)] | _ => []) rs)
    fun ambiguities rs = flat(map (fn Ambiguous(f,S) => [(f,S)] | _ => []) rs)
    fun outcome S (f,arities) =
        case maximaw a (dom_below(a,arities,S)) of
            [] => NoDom
          | [w] => Unique(f,w)
          | _ => Ambiguous(f,S)
    val table = map (fn S => map (outcome S) decls) (0 upto (Array.length(a)-1))
in coreg_errs(flat(map ambiguities table), sorts);
   Array.arrayoflist(map uniques table)
end;


(* The unique minimal result sort among the arities whose domain lies above w.
   Raises TYPE("",[],[]) if there is none or more than one. *)
fun least_cod_above(a,w,ars) =
    let val candidates = cod_above(a,w,ars)
    in case minima a candidates of
         [S] => S
       | _ => raise TYPE("",[],[])
    end;

(* Least sort of a type.  TFrees and TVars carry their sort.  For c(T1,..,Tn)
   the result is the least result sort among the arities of c whose domain lies
   above the least sorts of T1,..,Tn.
   Raises TYPE(undcl_type c,[U],[]) for an undeclared constructor c in
   subtype U, and TYPE("Type ill formed.",[U],[]) for the innermost subtype U
   that has no unique least sort. *)
fun least_sort({subsort,types,...}:type_sig) =
let fun ls(T as Type(a,Ts)) =
	let val ars = case assoc(types,a) of Some(ars) => ars
			 | None => raise TYPE(undcl_type a,[T],[]);
	    (* The argument sorts are computed outside the handler below, so an
	       error in a subtype keeps pointing at that subtype instead of being
	       rewritten to refer to T. *)
	    val Ss = map ls Ts
	in least_cod_above(subsort,Ss,ars)
	   handle TYPE _ => raise TYPE("Type ill formed.", [T],[])
	end
      | ls(TFree(_,S)) = S
      | ls(TVar(_,S)) = S
in ls end;

(* Is the least sort of T below s? *)
fun has_sort(tsig as {subsort,...}:type_sig,T,s) =
    let val least = least_sort tsig T
    in Array.sub(Array.sub(subsort,least),s) end;

(* Raise TYPE unless T has sort S *)
fun check_has_sort(tsig as {sorts,subsort,...}:type_sig,T,S) =
    if has_sort(tsig,T,S) then ()
    else raise TYPE("Type not of sort " ^ nth_elem(S,sorts),[T],[]);

(*Instantiation of type variables in types.  Each TVar bound in tye is
  replaced by its image (not instantiated further), after checking that the
  image has the TVar's sort; raises TYPE otherwise.*)
fun inst_typ_tvars(tsig,tye) =
let fun subst(T as TVar(v,S)) =
          (case assoc(tye,v) of
             Some(U) => (check_has_sort(tsig,U,S); U)
           | None => T)
      | subst(Type(a,Ts)) = Type(a, map subst Ts)
      | subst(T) = T
in subst end;

(*Instantiation of type variables in all types occurring in a term *)
fun inst_term_tvars(tsig,tye) t = map_term_types (inst_typ_tvars(tsig,tye)) t;

(*Raised by typ_match when the pattern type cannot be matched*)
exception TYPE_MATCH;

(* Type matching:
   typ_match tsig (s,(U,T)) = s'  where s'(U) = T and s' extends s.
   A newly bound TVar must be matched against a type of its sort.
   Raises TYPE_MATCH if no such s' exists. *)
fun typ_match tsig =
let fun tm(subs, (TVar(v,S), T)) =
	  (case assoc(subs,v) of
	     Some(U) => if U=T then subs else raise TYPE_MATCH
	   | None => ((check_has_sort(tsig,T,S); (v,T)::subs)
		      handle TYPE _ => raise TYPE_MATCH))
      | tm(subs, (Type(a,Ts), Type(b,Us))) =
	  if a=b then foldl tm (subs, Ts~~Us) else raise TYPE_MATCH
      | tm(subs, (TFree(x), TFree(y))) =
	  if x<>y then raise TYPE_MATCH else subs
      | tm _ = raise TYPE_MATCH
in tm end;

(* typ_instance(tsig,T,U): is T an instance of U (respecting sorts)? *)
fun typ_instance(tsig,T,U) =
    (typ_match tsig ([],(U,T)); true) handle TYPE_MATCH => false;


(* EXTENDING AND MERGING TYPE SIGNATURES *)

(* Errors for malformed names and duplicate type constructor declarations *)
fun not_ident s = error(implode["Must be an identifier: ", s]);

fun twice a = error(implode["Type constructor ", a, " has already been declared."]);

(*Is the type valid? Prepends error messages to "errs".
  The following are reported: undeclared type constructors, constructors
  applied to the wrong number of arguments, and sort indices outside
  0 .. maxS.  Errors found in argument types are kept even when the
  enclosing constructor is erroneous too.  If none of these occur, the type
  must also have a least sort; otherwise "Ill-formed type: ..." is reported
  for the subtype named by least_sort.*)
fun type_errors (tsig as {sorts,types,subsort,...}:type_sig, string_of_typ)
		(T,errs) =
let val maxS = Array.length(subsort)-1;
    fun sort_err(S,errs) =
	if 0 <= S andalso S <= maxS then errs
	else (out_of_range(S,maxS)) :: errs
    fun errors(Type(c,Us), errs) =
	let val errs' = foldr errors (Us,errs)	(* errors in the arguments *)
	in case assoc(types,c) of
	     None => (undcl_type c) :: errs'
	   | Some(ars) => if length(fst(hd ars))=length(Us) then errs'
		else ("Wrong number of arguments: " ^ c) :: errs'
	end
      | errors(TFree(_,S), errs) = sort_err(S,errs)
      | errors(TVar(_,S), errs) = sort_err(S,errs);
in case errors(T,[]) of
     [] => ((least_sort tsig T; errs)
	    handle TYPE(_,[U],_) => ("Ill-formed type: " ^ string_of_typ U)
				    :: errs)
   | errs' => errs'@errs
end;

(* add_sorts a (sorts,newsorts): declare each (s,ges) of newsorts in order,
   where ges lists the names of s's supersorts.  The new sort gets index
   i = length sorts.  Row i of the (already enlarged) subsort matrix a is
   updated in place: it is made reflexive, and every supersort of each s' in
   ges becomes true.  Reports an error if s is not an identifier, if s is
   already declared, or if some s' in ges is undeclared.  Returns the
   extended list of sort names. *)
fun add_sorts a =
let fun add sorts s =
    let val i = length sorts;
	val sa = Array.sub(a,i);
        fun upd s' =
	    let val j = find(s',sorts);
		(* copy row j of a into row i; only columns 0..j are scanned
		   (presumably because the supersorts of j are declared before j) *)
		fun trans(k) = if Array.sub(Array.sub(a,j),k)
				then Array.update(sa,k,true)
				else ()
	    in for(j, trans) end;
        fun iter(s'::ss) = if s' mem sorts then (upd s'; iter ss)
		else undcl_sort(s')
          | iter([]) = sorts@[s]
    in Array.update(sa,i,true); iter end;
    fun add_sort(sorts,(s,ges)) =
	if not(Syntax.is_identifier s) then not_ident s else
	if s mem sorts then error("Sort " ^ s ^ " declared twice.")
	else add sorts s ges;
in foldl add_sort end;


(* add_type sorts test (types,(cs,(w,s))): add the arity (w)s, with sort names
   translated to indices, to each type constructor in cs.  test(c) is called
   first for each c (extend uses it to require that c is old, or that c is
   new).  All arities of one constructor must have the same number of
   arguments.  Reports an error for non-identifiers, undeclared sort names,
   or varying argument counts. *)
fun add_type sorts test =
let fun ix s = find(s,sorts) handle LIST _ => undcl_sort(s);
    (* insert arity (w,s) into ars, requiring exactly n argument sorts *)
    fun add_ar(c, n, ars, (w,s)) =
	if length(w)=n then (map ix w,ix s) ins ars
	else error("Type constructor "^c^" has varying number of arguments.")
    fun add (ar as (w,_)) = fn (types,c) =>
	if not(Syntax.is_identifier c) then not_ident c else
	(test(c);
	 case assoc(types,c) of
	   Some(ars') =>
		let val ars'' = add_ar(c,length(fst(hd ars')),ars',ar)
		in overwrite(types,(c,ars'')) end
	 | None => (c,add_ar(c,length(w),[],ar)) :: types)
    fun adds(types,(cs,ar)) = foldl (add ar) (types,cs);
in adds end;

(* Add newsorts to the sort structure: enlarge the subsort matrix, record the
   declared supersorts, and recompute the infimum matrix.  The structure is
   returned unchanged when there are no new sorts. *)
fun extend_sorts(sorts,subsort,inf,newsorts) =
    case newsorts of
      [] => (sorts,subsort,inf)
    | _ =>
        let val subsort' = extend_matrix(subsort,length newsorts,false)
            val sorts' = add_sorts subsort' (sorts,newsorts)
        in (sorts', subsort', downward_complete(sorts',subsort')) end;

(* Extend a type signature with new sorts, additional arities for existing
   type constructors (otypes) and new type constructors (newtypes).  The
   result is checked for regularity, and its coregularity table is recomputed. *)
fun extend({sorts,subsort,inf,types,...}:type_sig,newsorts,otypes,newtypes) =
let val (sorts',subsort',inf') = extend_sorts(sorts,subsort,inf,newsorts)
    val declared = map fst types
    fun must_be_old c = if c mem declared then () else undcl_type_err c
    fun must_be_new c = if c mem declared then twice c else ()
    val types1 = foldl (add_type sorts' must_be_old) (types, otypes)
    val types2 = foldl (add_type sorts' must_be_new) (types1, newtypes)
in regularity(sorts',subsort',types2);
   {sorts=sorts', subsort=subsort', inf=inf',
    coreg=coregular(sorts',subsort',types2), types=types2}
end;


(* Merge two type signatures.  One sort structure must extend the other: its
   sort list has the other's as a prefix, and its subsort matrix contains the
   other's.  The result uses the larger structure.  Arities of constructors
   declared in both signatures are united into a single entry, and the
   coregularity table is recomputed.  Reports an error if the sort structures
   are incompatible. *)
fun merge({sorts=sorts1, subsort=sub1, inf=inf1, coreg=coreg1,
	   types=types1}:type_sig,
	  {sorts=sorts2, subsort=sub2, inf=inf2, coreg=coreg2,
	   types=types2}:type_sig):type_sig =
let (* Replace an existing entry for f instead of prepending a second one.
       A duplicate would leave a stale, smaller arity list in types, which
       regularity_errs and coregular would still examine. *)
    fun add_type((f,ars),types) = (case assoc(types,f) of
		Some(ars') => overwrite(types,(f,ars union ars'))
	      | None => (f,ars)::types);
    val types = foldr add_type (types1,types2);
in if (sorts1 prefix sorts2) andalso (sub1 submatrix sub2)
   then {sorts=sorts2, subsort=sub2, inf=inf2, types=types,
	 coreg=coregular(sorts2,sub2,types)}
   else if (sorts2 prefix sorts1) andalso (sub2 submatrix sub1)
   then {sorts=sorts1, subsort=sub1, inf=inf1, types=types,
	 coreg=coregular(sorts1,sub1,types)}
   else error("Cannot merge signatures with incompatible sort structures.")
end;


(**** TYPE INFERENCE ****)

(* 1. Freeze all TVars in constraints into TFrees using mark_free.
   2. Carry out type inference, possibly introducing TVars.
   3. Rename and freeze all TVars.
   4. Thaw all TVars frozen in step 1.
*)

(*Raised if types are not unifiable.  This includes the case where two sorts
  that must be intersected have no infimum (see gen_opt_tyvar).*)
exception TUNIFY;

(* Counter for fresh TVar indices.  After tyinit, new_tvar_inx returns
   ~2, ~3, ... in turn; ~1 is the index used by id_type. *)
val tyvar_count = ref(~1);

fun tyinit() = tyvar_count := ~1;

fun new_tvar_inx() =
    let val next = !tyvar_count - 1
    in tyvar_count := next; next end;

(* Generate a new TVar of sort S.
   The name is arbitrary, because all TVars in type constraints have been
   frozen.  The index is new and negative, which distinguishes it from the
   TVars generated from variable names (see id_type). *)
fun gen_tyvar S =
    let val inx = new_tvar_inx()
    in TVar(("'a", inx), S) end;

(* A fresh TVar of the given sort; raises TUNIFY if there is no sort (no infimum) *)
fun gen_opt_tyvar optS =
    case optS of
      Some S => gen_tyvar S
    | None => raise TUNIFY;

(*Occurs check: does type variable v occur in type T, following the bindings in tye?*)
fun occ (v, T, tye) =
    case T of
      Type(_,Ts) => exists (fn U => occ(v,U,tye)) Ts
    | TFree _ => false
    | TVar(w,_) =>
        v=w orelse (case assoc(tye,w) of
                      Some U => occ(v,U,tye)
                    | None => false);


(*Chase variable assignments in tye.
  If devar (T,tye) returns a type var then it is unassigned.*)
fun devar (T, tye) =
    case T of
      TVar(v,_) => (case assoc(tye,v) of
                      Some U => devar(U,tye)
                    | None => T)
    | _ => T;


(* weaken(U,S,tsig,tye): extend tye so that U (after instantiation by tye)
   has sort S.
   - Type(f,Ts): the domain w of f is taken from the coregularity table for S,
     and each Ti is weakened to the corresponding sort in w.
   - TFree(_,S'): S' must already be below S.
   - Unassigned TVar(v,S') not below S: v is bound to a fresh TVar whose sort
     is the infimum of S and S'.
   Raises TUNIFY if any of these is impossible. *)
fun weaken(U,S,{subsort,inf,coreg,...}:type_sig,tye) =
    let fun wk(tye,(T,S)) = case devar(T,tye) of
	      Type(f,Ts) =>
		(case assoc(Array.sub(coreg,S),f) of
		   Some(w) => foldl wk (tye,Ts~~w)
		 | None => raise TUNIFY)
	    | TFree(_,S') => if Array.sub(Array.sub(subsort,S'),S) then tye
			     else raise TUNIFY
	    | TVar(v,S') => if Array.sub(Array.sub(subsort,S'),S) then tye
		else (v,gen_opt_tyvar(Array.sub(Array.sub(inf,S),S')))::tye
    in wk(tye,(U,S)) end;

(* Unification of types *)
(* Precondition: both types are well-formed w.r.t. type constructor arities *)
(* unify tsig ((T,U),tye) extends substitution tye to a sort-respecting
   unifier of T and U.  Raises TUNIFY if none exists.
   Two unassigned TVars: the one with the larger sort is bound to the other;
   if the sorts are incomparable, both are bound to a fresh TVar of their
   infimum.  A TVar and a non-variable type: occurs check, then the type is
   weakened to the TVar's sort. *)

fun unify(tsig as {subsort,inf,...}:type_sig) =
let fun unif_var(T as TVar(v,S1), U as TVar(w,S2), tye) =
	if v=w then tye else
	if Array.sub(Array.sub(subsort,S1),S2) then (w,T)::tye else
	if Array.sub(Array.sub(subsort,S2),S1) then (v,U)::tye
	else let val nu = gen_opt_tyvar(Array.sub(Array.sub(inf,S1),S2))
	     in (v,nu)::(w,nu)::tye end
      | unif_var(T as TVar(v,S), U, tye) =
	if occ(v,U,tye) then raise TUNIFY else (v,U)::weaken(U,S,tsig,tye)
      | unif_var(T,U,tye) = unif((T,U),tye)
    and unif((T as TVar _,U), tye) = unif_var(devar(T, tye), devar(U,tye), tye)
      | unif((T,U as TVar _), tye) = unif_var(devar(U, tye), devar(T,tye), tye)
      | unif((Type(a,Ts),Type(b,Us)), tye) = if a<>b then raise TUNIFY
		else unif_types (Ts~~Us, tye)
      | unif((TFree(x),TFree(y)), tye) = if x=y then tye else raise TUNIFY
      | unif _ = raise TUNIFY
    and unif_types(TUs,tye) = foldr unif (TUs,tye)
in unif end;


(*Instantiation of type variables in types, following bindings in tye
  repeatedly (images are instantiated as well).*)
(*Pre: instantiations obey restrictions! *)
fun inst_typ tye =
let fun inst(T as TVar(v,_)) =
          (case assoc(tye,v) of
             None => T
           | Some U => inst U)
      | inst(Type(a,Ts)) = Type(a, map inst Ts)
      | inst(T) = T
in inst end;

(*Type inference for polymorphic term.
  infer1 tsig (Ts, t, tye) returns (type of t, extended tye), where Ts holds
  the types of the loose bound variables (Bound i has type nth_elem(i,Ts)).
  In an application f$u the argument u is inferred before f.
  Raises TYPE for loose bound variables and type mismatches.*)
fun infer1 tsig =
let fun i1(Ts, Const (_,T), tye) = (T,tye)
      | i1 (Ts, Free  (_,T), tye) = (T,tye)
      | i1 (Ts, Bound i, tye) = ((nth_elem(i,Ts) , tye)
      handle LIST _=> raise TYPE ("loose bound variable", [], [Bound i]))
      | i1 (Ts, Var (_,T), tye) = (T,tye)
      | i1 (Ts, Abs (_,T,body), tye) = 
	let val (U,tye') = i1(T::Ts, body, tye) in  (T-->U, tye')  end
     | i1 (Ts, f$u, tye) = 
	let val (U,tyeU) = i1 (Ts, u, tye);
	    val (T,tyeT) = i1 (Ts, f, tyeU)
	in (case T of
	      Type("fun",[T1,T2]) =>
		( (T2, unify tsig ((T1,U), tyeT)) handle TUNIFY =>
		  raise TYPE("type mismatch in application",
			     [inst_typ tyeT T, inst_typ tyeT U], [f$u]))
	    (* rator type unknown: unify it with U --> fresh result type *)
	    | TVar _ => 
		let val T2 = gen_tyvar(any)
		in (T2, unify tsig ((T, U-->T2), tyeT)) handle TUNIFY =>
		   raise TYPE("type mismatch in application",
			      [inst_typ tyeT T, inst_typ tyeT U], [f$u])
		end
	    | _ => raise TYPE("rator must have function type",
			      [inst_typ tyeT T, inst_typ tyeT U], [f$u]))
	end
in i1 end;


(* Freeze every TVar into a TFree named by its printed variable name
   (Syntax.string_of_vname); thaw_tvars recognises names starting with ?' *)
fun mark_free T =
    case T of
      Type(a,Ts) => Type(a, map mark_free Ts)
    | TVar(v,S) => TFree(Syntax.string_of_vname v, S)
    | TFree _ => T;


(* Attach a type to a constant, renaming its TVars apart by a fresh
   negative index increment *)
fun type_const (a,T) =
    let val inx = new_tvar_inx()
    in Const(a, incr_tvar inx T) end;


(*Find type of ident.  If it is not in the table, the ident's name is used
  for the tyvar, so that all occurrences get the same type.*)
fun id_type a = TVar(("'" ^ a, ~1), any);

fun id_def_type(a,defT) =
    (case defT a of
       None => id_type a
     | Some T => mark_free T);

(* Wrap term in an explicit type constraint of type T *)
fun constrain(term,T) =
    let val c = Const(Syntax.constrainC, T-->T)
    in c $ term end;
(* Replace the binder type of an abstraction by the constraint T.
   A constraint on anything other than an abstraction is reported as a TYPE
   error, like other inference failures, instead of an anonymous Match. *)
fun constrainAbs(Abs(a,_,body),T) = Abs(a,T,body)
  | constrainAbs(t,_) =
      raise TYPE("type constraint on non-abstraction", [], [t]);


(*Attach types to a term.  Input is a "parse tree" containing dummy types.
  Leave in type of _constrain (essential for it to work!) *)
(* TVars in _constrain are frozen by turning them into new TFrees *)
(* Type constraints are translated and checked for validity wrt tsig *)
(* Constants and Frees found in const_tab get a renamed copy of the table
   type; other Frees, Vars and bound names get a TVar derived from their name
   (or their type from defT, frozen).  Type variables without a sort given by
   defS get the default sort.  Raises TYPE for unknown constants and invalid
   constraint types. *)
fun add_types (tsig, const_tab, defT, defS, string_of_typ) =
let val S0 = defaultS tsig;
    fun defS0 s = case defS s of Some S => S | None => S0;
    (* translate a constraint term into a frozen type, after checking it *)
    fun prepareT(typ) =
	let val T = Syntax.typ_of_term (defS0,sort_of_string tsig) typ;
	    val T' = mark_free T
	in case type_errors (tsig,string_of_typ) (T,[]) of
	     [] => T'
	   | errs => raise TYPE(cat_lines errs,[T],[])
	end
    fun add (Const(a,_)) =
            (case Symtab.lookup(const_tab, a) of
                Some T => type_const(a,T)
              | None => raise TYPE ("No such constant: "^a, [], []))
      | add (Bound i) = Bound i
      | add (Free(a,_)) =
            (case Symtab.lookup(const_tab, a) of
                Some T => type_const(a,T)
              | None => Free(a, id_def_type(a, defT)))
      | add (Var((a,i),_)) = Var((a,i), id_def_type(a, defT))
      | add (Abs(a,_,body)) = Abs(a, id_type a, add body)
      | add ((f as Const(a,_)$t1)$t2) =
	if a=Syntax.constrainC then constrain(add t1,prepareT t2) else
	if a=Syntax.constrainAbsC then constrainAbs(add t1,prepareT t2)
	else add f $ add t2
      | add (f$t) = add f $ add t
in  add  end;


(* Post-Processing *)


(*Instantiation of type variables in all types of a term, following tye fully*)
fun inst_types tye t = map_term_types (inst_typ tye) t;

(*Delete explicit constraints -- occurrences of "_constrain" *)
fun unconstrain t =
    let fun is_constraint (Const(a,_)) = a = Syntax.constrainC
          | is_constraint _ = false
    in case t of
         Abs(a,T,body) => Abs(a, T, unconstrain body)
       | f $ u => if is_constraint f then unconstrain u
                  else unconstrain f $ unconstrain u
       | _ => t
    end;


(*Accumulates the TFrees and TVars in a typ, suppressing duplicates. *)
fun add_tfreevars(T,(fs,vs)) =
    let val fs' = add_typ_tfrees(T,fs)
        val vs' = add_typ_tvars(T,vs)
    in (fs',vs') end;

(*Collects the TFrees and TVars in a term, suppressing duplicates. *)
fun term_tfreevars t =
    let val empty = ([],[])
    in it_term_types add_tfreevars (t,empty) end;

(* Turn all TVars except those in fixed into new TFrees.  The new names are
   variants of the TVar names that avoid the TFree names already in t. *)
fun freeze(t,fixed) =
let val (tfrees, tvars) = term_tfreevars t
    val to_freeze = (map fst tvars) \\ fixed
    val renaming = to_freeze ~~ variantlist(map fst to_freeze, tfrees)
    fun fr(T as TVar(v,S)) =
          (case assoc(renaming,v) of
             Some(a) => TFree(a,S)
           | None => T)
      | fr(Type(a,Ts)) = Type(a, map fr Ts)
      | fr(T) = T
in map_term_types fr t end;


(* Thaw all TVars that were frozen in mark_free: a TFree whose name starts
   with ?' is parsed back into a TVar *)
fun thaw_tvars T =
    case T of
      Type(a,Ts) => Type(a, map thaw_tvars Ts)
    | TFree(a,S) =>
        (case explode a of
           "?" :: "'" :: vn =>
             let val ((b,i),_) = Syntax.scan_varname vn
             in TVar(("'" ^ b, i), S) end
         | _ => T)
    | TVar _ => T;


(* Drop the bindings of generated TVars (negative index) from tye and fully
   instantiate the remaining images.  The result is in reverse order. *)
fun restrict tye =
let fun go([], acc) = acc
      | go(((a,i),T)::rest, acc) =
          go(rest, if i < 0 then acc else ((a,i), inst_typ tye T) :: acc)
in go(tye, []) end;


(*Infer types for term t using tables.
  Check that t's type and T unify *)
(* Returns the typed term with its constraints removed, plus the
   instantiation of the surviving (non-generated) TVars.  The remaining TVars
   are frozen and then thawed, so TVars frozen by mark_free come back.
   The dummy Const("",Type("",...)) wrapper carries the instantiation types
   through freeze/thaw together with the term.
   Raises TYPE if the term's type does not unify with T. *)
fun infer_term (tsig, const_tab, defT, defS, string_of_typ, T, t) =
let val u = add_types (tsig, const_tab, defT, defS, string_of_typ) t;
    val (U,tye) = infer1 tsig ([], u, []);
    val uu = unconstrain u;
    val tye' = unify tsig ((T,U),tye) handle TUNIFY => raise TYPE
	("Term does not have expected type", [T, U], [inst_types tye uu])
    val Ttye = restrict tye' (* restriction to TVars in T *);
    val all = Const("",Type("",map snd Ttye)) $ (inst_types tye' uu)
		(* contains all exported TVars *);
    val vars = map fst (add_typ_tvars(T,[]));
    val Const(_,Type(_,Ts)) $ u'' =
		map_term_types thaw_tvars (freeze(all,vars));
in (u'', (map #1 Ttye) ~~ Ts) end;

(* Entry point for type inference: reset the fresh-TVar counter, then run infer_term *)
fun infer_types args =
    let val () = tyinit()
    in infer_term args end;

(* Turn TFrees into TVars (index 0) to allow types & axioms to be written without "?" *)
fun varifyT T =
    case T of
      Type(a,Ts) => Type(a, map varifyT Ts)
    | TFree(a,S) => TVar((a,0),S)
    | TVar _ => T;

(* Turn TFrees except those in fixed into new TVars (index 0).  The names
   are variants that avoid the names of the TVars already in t. *)
fun varify(t,fixed) =
let val (tfrees, tvars) = term_tfreevars t
    val to_thaw = tfrees \\ fixed
    val renaming = to_thaw ~~ variantlist(to_thaw, map (fst o fst) tvars)
    fun th(T as TFree(a,S)) =
          (case assoc(renaming,a) of
             Some b => TVar((b,0),S)
           | None => T)
      | th(Type(a,Ts)) = Type(a, map th Ts)
      | th(T) = T
in map_term_types th t end;


end;
