(*  Title: 	sign
    Author: 	Lawrence C Paulson, Cambridge University Computer Laboratory
    Copyright   1989  University of Cambridge

PROBLEMS:  mapping between types and nonterminals

  the abstract types "sg" (signatures)
  and "cterm" (certified terms under a signature)

can "merge" wrongly raise exn TABLE??
*)


(*Interface of the signature module: theory signatures (sg) and terms
  certified under them (cterm).*)
signature SIGN = 
sig
  structure Syntax: SYNTAX
  (*sg: signature of a theory;  cterm: term certified under a signature*)
  type sg and cterm
  (*Certify a term: checks its constants and types against the signature,
    raising TERM_ERROR if any are invalid*)
  val cterm_of: sg -> term -> cterm
  (*Add ground types and constant declarations and install a new syntax;
    the string argument is recorded as a fresh stamp*)
  val extend: sg -> string -> string list * (string list * typ)list * Syntax.syntax -> sg
  (*Look up a constant among declarations; raises TERM_ERROR if undeclared*)
  val get_const: (string list * typ)list -> string -> term
  (*Union of two signatures (returns one argument unchanged if it already
    contains the other)*)
  val merge: sg * sg -> sg
  val print_cterm: cterm -> unit
  val print_term: sg -> term -> unit
  (*The base signature*)
  val pure: sg
  (*Parse a string, infer its types against the expected typ, and certify it*)
  val read_cterm: sg -> string * typ -> cterm
  (*Read (var name, term string, typ) triples into (Var, term) cterm pairs*)
  val read_insts: sg -> (string*string*typ)list -> (cterm*cterm)list
  val rep_cterm: cterm -> {T: typ, t: term, sign: sg, maxidx: int}
  val rep_sg: sg-> {gnd_types: string list, const_tab: typ Syntax.Symtab.table, ident_tab: typ Syntax.Symtab.table, syn: Syntax.syntax, stamps: string ref list}
  val term_of: cterm -> term
  (*Record the types of the term's identifiers in the cterm's signature*)
  val type_assign: cterm -> cterm
end;


functor SignFun (Syntax: SYNTAX) : SIGN = 
struct
(*Re-export the syntax parameter (required by SIGN) and abbreviate its
  symbol-table structure for local use.*)
structure Syntax = Syntax
and Symtab = Syntax.Symtab;


(*Signature of a theory.  Every stamp is a freshly allocated ref (see
  extend), so a signature cannot be forged.*)
type sg =
  {stamps: string ref list,		(*unique theory identifiers*)
   gnd_types: string list,		(*ground types*)
   const_tab: typ Symtab.table,		(*types of constants*)
   ident_tab: typ Symtab.table,		(*default types of identifiers*)
   syn: Syntax.syntax};			(*parsing and printing operations*)


(*Expose the record representation of a signature.*)
fun rep_sg (sign : sg) = sign;


(*Join messages into a single string, each followed by a newline.
  Duplicates are removed first (via distinct).*)
fun cat_lines msgs = implode (map (apr (op^, "\n")) (distinct msgs));


(*Check a type against the ground types Ts, consing a message onto errs
  for each undeclared ground type and each type variable (only monotypes
  are valid).  Returns the extended message list.*)
fun type_errors Ts (T, errs) =
    case T of
	Ground a =>
	  if a mem Ts then errs else ("Undeclared type: " ^ a) :: errs
      | Poly a => ("Type variable: " ^ a) :: errs
      | T1 --> T2 => type_errors Ts (T1, type_errors Ts (T2, errs));



(*Does constant a have a legal type T?  The built-in constants "==>", "all"
  and "==" are checked against their fixed type schemes.  Any other constant
  must be in ctab with exactly type T.*)
fun valid_const ctab (a, T) =
    case a of
	"==>" => T = (Aprop-->Aprop-->Aprop)
      | "all" =>
	  (case T of
	       (U --> _) --> _ => T = ((U-->Aprop)-->Aprop)
	     | _ => false)
      | "==" =>
	  (case T of
	       U --> (_ --> _) => T = (U-->(U-->Aprop))
	     | _ => false)
      | _ =>
	  (case Symtab.lookup (ctab, a) of
	       Some U => T = U
	     | _ => false);


(*Collect error messages for a term w.r.t. a signature: illegal constant
  types, invalid types anywhere, negative Var indices.  Returns a function
  taking (term, initial messages).  Well-typedness is NOT checked; loose
  bound variables are left to type_of.*)
fun term_errors ({gnd_types,const_tab,...}: sg) =
  let val tyerrs = type_errors gnd_types
      fun check (f $ t, errs) = check (f, check (t, errs))
	| check (Abs (_, T, body), errs) = tyerrs (T, check (body, errs))
	| check (Bound _, errs) = errs
	| check (Const (a, T), errs) =
	    if valid_const const_tab (a, T) then tyerrs (T, errs)
	    else ("Illegal type for constant: " ^ a ^ ": " ^ string_of_type T)
		 :: errs
	| check (Free (_, T), errs) = tyerrs (T, errs)
	| check (Var ((a, i), T), errs) =
	    if i < 0 then ("Negative index for Var: " ^ a) :: errs
	    else tyerrs (T, errs)
  in  check  end;


(*Constants of the pure signature.  "==", "all" and "_constrain" are listed
  at an arbitrary monotype (over prop).*)
val pure_cpairs =
  let val prop2 = Aprop --> Aprop
  in  [("==",		Aprop --> prop2),	(*equality*)
       ("==>",		Aprop --> prop2),	(*implication*)
       ("all",		prop2 --> Aprop),	(*universal quantifier*)
       ("_constrain",	prop2)]			(*type constraint operator*)
  end;


(*The pure signature: ground type "prop", the constants of pure_cpairs,
  no identifier types, the pure syntax and no stamps.*)
val pure : sg =
  let val ctab = Symtab.st_of_alist (pure_cpairs, Symtab.null)
  in  {gnd_types = ["prop"], const_tab = ctab, ident_tab = Symtab.null,
       syn = Syntax.pure, stamps = []}
  end;



(** The Extend operation **)


(*Validate the type of every constant declaration against the ground types
  Ts.  Stops with error at the first declaration whose type is invalid.*)
fun check_consts Ts [] = ()
  | check_consts Ts ((cs,T)::pairs) =
      let val errs = type_errors Ts (T, [])
      in  if errs = [] then check_consts Ts pairs
	  else error (cat_lines
		(("Error in type of constants " ^ space_implode " " cs) :: errs))
      end;


(*Extend a signature with new ground types and constant declarations, and
  install syn as its syntax.  The identifier table is kept.  A fresh ref
  holding signame is added to the stamps, so no two signatures are identical
  and none can be forged.*)
fun extend (sign: sg) signame (newtypes, const_decs, syn) : sg =
    let val {gnd_types, const_tab, ident_tab, stamps, ...} = sign
	val all_types = gnd_types union newtypes
    in  (*fail before building anything if a declared type is invalid*)
	check_consts all_types const_decs;
	{gnd_types = all_types,
	 const_tab = Symtab.st_of_declist (const_decs, const_tab),
	 ident_tab = ident_tab,
	 syn = syn,
	 stamps = ref signame :: stamps}
    end;


(*Build a lookup function over const_decs; the table is built once, when
  get_const is applied to the declarations.  The lookup raises TERM_ERROR for
  an undeclared name.*)
fun get_const const_decs =
    let val ctab = Symtab.st_of_declist (const_decs, Symtab.null)
    in  fn a =>
	  (case Symtab.lookup (ctab, a) of
	       Some T => Const (a, T)
	     | _ => raise TERM_ERROR ("get_const: " ^ a, []))
    end;


(** The Merge operation **)

(*Add (a,x) to tab.  If a is already bound it must be bound to x, in which
  case tab is returned unchanged; otherwise TERM_ERROR is raised.*)
fun update_eq ((a, x), tab) =
    case Symtab.lookup (tab, a) of
	Some y =>
	  if x <> y
	  then raise TERM_ERROR ("Incompatible types for constant: "^a, [])
	  else tab
      | None => Symtab.update ((a, x), tab);

(*Add the entries of tab1 to tab2 with update_eq, so conflicting bindings
  raise TERM_ERROR.  The result is rebalanced.*)
fun merge_tabs (tab1, tab2) =
    let val merged = itlist_right update_eq (Symtab.alist_of tab1, tab2)
    in  Symtab.balance merged  end;

(*Add the entries of tab1 to tab2 with Symtab.update, so tab1's bindings
  replace tab2's.  The result is rebalanced.*)
fun smash_tabs (tab1, tab2) =
    let val smashed = itlist_right Symtab.update (Symtab.alist_of tab1, tab2)
    in  Symtab.balance smashed  end;


(*Merge two signatures.  If one's stamps include the other's, that one is
  returned as is (sign1 when the stamps coincide).  Otherwise the union is
  formed: constant tables must agree (merge_tabs); in identifier tables
  sign1 wins (smash_tabs); syntaxes are merged with Syntax.merge.*)
fun merge (sign1:sg, sign2:sg) =
    let val stamps1 = #stamps sign1
	and stamps2 = #stamps sign2
    in
      if stamps2 subset stamps1 then sign1
      else if stamps1 subset stamps2 then sign2
      else
	{gnd_types = #gnd_types sign1 union #gnd_types sign2,
	 const_tab = merge_tabs (#const_tab sign1, #const_tab sign2),
	 ident_tab = smash_tabs (#ident_tab sign1, #ident_tab sign2),
	 stamps = stamps1 union stamps2,
	 syn = Syntax.merge (#syn sign1, #syn sign2)}
    end;


(**** TYPE INFERENCE ****)


(*Counter used to number generated type variables.*)
val tyvar_count = ref 0;

(*Restart the numbering of generated type variables.*)
fun tyinit () = tyvar_count := 0;

(*Return a fresh type variable ".n", where n is the incremented counter.
  The leading "." presumably keeps it apart from type variables named after
  identifiers (see id_type) -- assumes identifiers never start with ".".*)
fun gen_tyvar () =
    let val n = !tyvar_count + 1
    in  tyvar_count := n;
	Poly ("." ^ string_of_int n)
    end;


(*Occurs check: does type variable a occur in T, following the assignments
  in tye through any variables met along the way?*)
fun occ (a, T, tye) =
    case T of
	Ground _ => false
      | Poly b =>
	  a = b orelse
	  (case assoc (tye, b) of
	       None => false
	     | Some U => occ (a, U, tye))
      | T1 --> T2 => occ (a, T1, tye) orelse occ (a, T2, tye);


(*Raised if types are not unifiable: raised by unify on a clash between
  distinct ground types or different type shapes, and by unify_var when the
  occurs check fails.  infer1 and infer_term turn it into TYPE.*)
exception TUNIFY;

(*Follow the assignments of tye starting from T.  The result is either a
  non-variable type or a type variable that tye leaves unassigned.*)
fun devar (T, tye) =
    case T of
	Poly a =>
	  (case assoc (tye, a) of
	       Some U => devar (U, tye)
	     | None => T)
      | _ => T;


(*Unification of types, extending the assignment list tye (pairs of
  variable name and type).  Raises TUNIFY when the types cannot be unified.

  unify_var: if its first argument is a variable (already dereferenced by the
  caller with devar), bind it to U after the occurs check.  Binding to itself
  is a no-op.  Otherwise fall back to structural unify.

  unify: both variable cases dereference both sides with devar first.  When
  both sides are variables the first clause applies, so argument order
  decides which variable gets bound.*)
fun unify_var (Poly a, U, tye) =  if U = Poly a  then  tye
         else if occ(a,U,tye) then  raise TUNIFY  else  (a,U)::tye
  | unify_var (T,U,tye) = unify(T,U,tye)
and unify (Poly a, T, tye) = unify_var (devar(Poly a, tye), devar(T,tye), tye)
  | unify (T, Poly a, tye) = unify_var (devar(Poly a, tye), devar(T,tye), tye)
  | unify (Ground a, Ground b, tye) =
        if a=b then tye  else  raise TUNIFY
  | unify (T1-->T2, U1-->U2, tye) =  
        (*domains first, then ranges under the extended assignment*)
        unify (T2, U2, unify (T1, U1, tye))
  | unify _ =  raise TUNIFY;



(*Apply the assignment tye to a type, repeatedly replacing assigned
  variables.  Unassigned variables are left in place.*)
fun inst tye T =
    case T of
	T1 --> T2 => inst tye T1 --> inst tye T2
      | Poly a =>
	  (case assoc (tye, a) of
	       Some U => inst tye U
	     | None => T)
      | Ground _ => T;

(*Apply the assignment tye to every type in a term.  An application of
  "_constrain" is replaced by its (instantiated) argument, which removes
  explicit type constraints.  That arm must precede the general application
  arm.*)
fun inst_term tye t =
    case t of
	Const ("_constrain", _) $ arg => inst_term tye arg
      | f $ u => inst_term tye f $ inst_term tye u
      | Abs (a, T, body) => Abs (a, inst tye T, inst_term tye body)
      | Const (a, T) => Const (a, inst tye T)
      | Free (a, T) => Free (a, inst tye T)
      | Var (v, T) => Var (v, inst tye T)
      | Bound i => Bound i;


(*Type inference for a polymorphic term.
  Ts:  types of the enclosing bound variables, innermost first (the Abs case
       pushes onto the front; Bound i is looked up with nth_elem).
  tye: current type-variable assignment (list of name/type pairs).
  Returns the term's type, not yet instantiated by tye, paired with the
  extended assignment.
  Raises TYPE for a loose bound variable, for a rator whose type is not a
  function type, or for a type mismatch in an application.*)
fun infer1 (Ts, Const (_,T), tye) = (T,tye)
  | infer1 (Ts, Free  (_,T), tye) = (T,tye)
  | infer1 (Ts, Bound i, tye) = ((nth_elem(i,Ts) , tye)
      handle LIST _=> raise TYPE ("loose bound variable", [], [Bound i]))
  | infer1 (Ts, Var (_,T), tye) = (T,tye)
  | infer1 (Ts, Abs (_,T,body), tye) = 
	let val (U,tye') = infer1(T::Ts, body, tye)
	in  (T-->U, tye')  end
  | infer1 (Ts, f$u, tye) = 
	(*infer the argument first, then the function under the extended tye*)
	let val (U,tyeU) = infer1 (Ts, u, tye);
	    val (T,tyeT) = infer1 (Ts, f, tyeU)
	in (case T of
	      T1-->T2 => (T2, unify(T1, U, tyeT))
	    | Poly _ => 
		(*rator type is a variable (possibly assigned in tyeT):
		  unify it with U --> T2 for a fresh result variable T2*)
		let val T2 = gen_tyvar()
		in  (T2, unify(T, U-->T2, tyeT))  end
	    | _ => raise TYPE ("Rator must have function type",
				  [inst tyeT T, inst tyeT U], [f$u]))
	  handle TUNIFY => raise TYPE
	     ("type mismatch in application", 
	      [inst tyeT T, inst tyeT U], [f$u])
	end;


(*Give constant a its type.  The polymorphic built-ins "==" and "all" get a
  fresh type variable each time.  Other constants get the supplied T.*)
fun type_const (a, T) =
    case a of
	"==" => equals (gen_tyvar ())
      | "all" => all (gen_tyvar ())
      | _ => Const (a, T);


(*Type of identifier a: its entry in itab if present.  Otherwise a type
  variable named after a, so every occurrence of a gets the same type.*)
fun id_type (itab, a) =
    (case Symtab.lookup (itab, a) of
	 None => Poly a
       | Some T => T);


(*Attach types to a parse tree, which carries dummy types.  Constants take
  their types from const_tab (an unknown Const raises TYPE).  A Free whose
  name is a constant becomes that constant.  Other Frees, Vars and bound
  variables get id_type types; Var keys are "name.index".  "_constrain"
  nodes keep their own type, which is essential for constraints to work, so
  that arm must come first.*)
fun add_types (const_tab, ident_tab) =
  let fun lookup_const a = Symtab.lookup (const_tab, a)
      fun add t =
	case t of
	    Const ("_constrain", _) => t
	  | Const (a, _) =>
	      (case lookup_const a of
		   Some T => type_const (a, T)
		 | None => raise TYPE ("No such constant: "^a, [], []))
	  | Free (a, _) =>
	      (case lookup_const a of
		   Some T => type_const (a, T)
		 | None => Free (a, id_type (ident_tab, a)))
	  | Var ((a, i), _) =>
	      Var ((a, i), id_type (ident_tab, a ^ "." ^ string_of_int i))
	  | Abs (a, _, body) => Abs (a, id_type (ident_tab, a), add body)
	  | f $ u => add f $ add u
	  | Bound i => Bound i
  in  add  end;


(*Infer types for the parse tree t using the constant and identifier tables,
  and check that the result has the expected type T.
  Returns t with all inferred types instantiated and "_constrain" nodes
  removed.
  Raises TYPE if t cannot be typed or its type does not unify with T.  As in
  infer1, the types in the message are instantiated by the current
  assignment, so an inferred type is shown as its binding rather than as an
  internal ".n" variable.*)
fun infer_term (const_tab, ident_tab, T, t) = 
    let val u = add_types (const_tab, ident_tab) t;
	val (U,tye) = infer1 ([], u, [])
    in  inst_term (unify(T, U, tye)) u  
      handle TUNIFY => raise TYPE
	("Term does not have expected type",
	 [inst tye T, inst tye U], [inst_term tye u])
    end;


(**** CERTIFIED TERMS ****)


(*A certified term: term t checked against signature sign, with its type T
  and maxidx, the maximum Var index of t as computed by maxidx_of_term.*)
datatype cterm = Cterm of {t: term, T: typ, maxidx: int, sign: sg};

(*Expose the fields of a certified term.*)
fun rep_cterm (Cterm fields) = fields;

(*The term inside a certified term.*)
fun term_of (Cterm {t, ...}) = t;


(*Certify a raw term under sign.  Raises TERM_ERROR listing every problem
  found by term_errors.  type_of may also raise on an ill-typed term.*)
fun cterm_of sign t =
  let val errs = term_errors sign (t, [])
  in  if errs = []
      then Cterm {sign = sign, t = t, T = type_of t, maxidx = maxidx_of_term t}
      else raise TERM_ERROR (cat_lines ("Term not in signature" :: errs), [t])
  end;


(*Parse string a with the signature's syntax, run polymorphic type inference
  against the expected type T (type-variable numbering restarts first), and
  certify the result.  TYPE and TERM_ERROR failures are reported via error.*)
fun read_cterm (sign as {const_tab, ident_tab, syn,...}: sg) (a,T) =
  let fun show_types Ts = implode (map (apl("   ",op^) o string_of_type) Ts)
  in  (tyinit ();
       cterm_of sign (infer_term (const_tab, ident_tab, T, Syntax.read syn T a)))
      handle TYPE (msg, Ts, _) =>
	       error ("Type checking error: " ^ msg ^ "\n" ^ show_types Ts ^ "\n")
	   | TERM_ERROR (msg, _) => error ("Error: " ^ msg)
  end;


(*Print a term using the signature's syntax.*)
fun print_term (sign: sg) = Syntax.prin (#syn sign);

(*Print a certified term under its own signature.*)
fun print_cterm (Cterm {sign, t, ...}) = print_term sign t;


(*Read an instantiation list of (var name, term string, typ) triples into
  (Var, term) cterm pairs, processed left to right.  The Var is built
  directly at typ T rather than read, because the signature may hold other
  var/type assignments.  A malformed var name is reported via error.*)
fun read_insts sign insts =
  let fun read_one (sv, st, T) =
	case Syntax.scan_varname (explode sv) of
	    ((a, i), []) =>
	      (cterm_of sign (Var ((a, i), T)), read_cterm sign (st, T))
	  | (_, cs) => error ("Lexical error in Var.  Location: "
			      ^ implode cs ^ "\n")
      fun read_all [] = []
	| read_all (x :: xs) =
	    let val pair = read_one x
	    in  pair :: read_all xs  end
  in  read_all insts  end;


(*Add identifier/type pairs to a signature to facilitate type inference.
  The Isabelle version of LCF sticky types! *)


(*Copy of sign with its identifier table replaced by itab.*)
fun new_ident_tab itab (sign: sg) : sg =
  {gnd_types = #gnd_types sign, const_tab = #const_tab sign,
   ident_tab = itab, syn = #syn sign, stamps = #stamps sign};

(*Record the type of every Free, Var and Abs-bound identifier of a term in
  itab; later entries overwrite earlier ones.  Vars are keyed by
  "name.index", the same key add_types uses when it looks up a Var's type.
  Keying them by the bare name meant recorded Var types were never found,
  and were wrongly applied to a Free or bound variable of that name.*)
fun add_typairs (Free(a,T), itab) = Symtab.update((a,T), itab)
  | add_typairs (Var((a,i),T), itab) =
		Symtab.update((a ^ "." ^ string_of_int i, T), itab)
  | add_typairs (Abs(a,T,t), itab) =
		Symtab.update((a,T), add_typairs (t, itab))
  | add_typairs (f$t, itab) = add_typairs (t, add_typairs (f, itab))
  | add_typairs (_, itab) = itab;


(*Return the same certified term, but with its signature's identifier table
  extended (and rebalanced) with the types of the term's identifiers.*)
fun type_assign (Cterm{sign,t,T,maxidx}) = 
    let val itab = add_typairs (t, #ident_tab sign)
	val sign' = new_ident_tab (Symtab.balance itab) sign
    in  Cterm {t = t, T = T, maxidx = maxidx, sign = sign'}  end;


end;
