(*  Title: 	rules
    Author: 	Lawrence C Paulson, Cambridge University Computer Laboratory
    Copyright   1986  University of Cambridge

  the abstract types "theory" and "rule"

  RENAME GOAL AS SENTENCE
  new_theory: verify arities in definitions in signature
  security requires an abstype of certified terms
    representing  sign*term*arity*int [max index]
    should be the result of rep_rule
    and the arguments to constrain, trivial (gets rid of triv_of_prem)
  equivalent_rule: allow renaming of vars and params,
    Also helpful in "batch" files that derive rules
*)


(*Largest index of any Var or Param occurring in a term.
  Constants and bound variables contribute 0; an application takes the
  max over operator and operand; an abstraction looks only at its body.*)
fun maxidx_of_term tm =
  case tm of
      Const _            => 0
    | Bound _            => 0
    | Var ((_,idx), _)   => idx
    | Param ((_,idx), _) => idx
    | Abs (_,_,body)     => maxidx_of_term body
    | f $ x              => max [maxidx_of_term f, maxidx_of_term x];


(*Largest index occurring in a list of terms; 0 for the empty list
  (max is only ever applied to a non-empty list here).*)
fun maxidx_of_terms tms =
  case map maxidx_of_term tms of
      []   => 0
    | idxs => max idxs;


(*Largest index occurring in a list of term pairs, folded from the right
  starting at 0 (so the empty list gives 0).*)
fun maxidx_of_tpairs tus =
  let fun step ((t,u), m) = max [maxidx_of_term t, maxidx_of_term u, m]
  in  itlist_right step (tus, 0)  end;


(*Normalize a dependency olist under the environment env: every dependency
    term is normalized with norm_term, and only the variables OR PARAMETERS
    of the results are kept (pvars_of_terms).
  Parameters were previously omitted, but are links in the dependency chain.
  Also, generalization turns these parameters into variables.
  Even if the dependency list is empty, we must record that the param was
  constrained.*)
fun norm_depolist env =
  xmap (fn deps => pvars_of_terms (map (norm_term env) deps));


(*Add every Param (name with its arity) occurring in a term to the olist
  params, using xinsert.  The (term, acc) argument shape makes it usable as
  the step function of itlist_right.*)
fun add_term_params (Param (name,ary), params) : arity xname_olist =
      xinsert ((name,ary), params)
  | add_term_params (Abs (_,_,body), params) =
      add_term_params (body, params)
  | add_term_params (rator $ rand, params) =
      let val params' = add_term_params (rand, params)
      in  add_term_params (rator, params')  end
  | add_term_params (_, params) = params;


(*Collect from ol those pairs whose keys appear in keyol.
  The contents part of keyol is ignored.  Surviving pairs are inserted with
  xinsert_new into a fresh olist, folding from the right.*)
fun filter_olist (keyol: 'a xname_olist,  ol: 'b xname_olist) : 'b xname_olist =
  let fun keep_if_key (pair as (key,_), acc) =
            (case xsearch (keyol, key) of
                 Some _ => xinsert_new (pair, acc)
               | None   => acc)
      val pairs = alist_of_olist ol
  in  itlist_right keep_if_key (pairs, null_olist)  end;


(*Generalize a term: replace each Param by a uvar (a Var with the same name
    and arity), except params that have an entry in the dependency olist
    depol; those are left as they are.*)
fun gen_term depol : term -> term =
  let fun gen tm =
        case tm of
            Param (name,ary) =>
              (case xsearch (depol,name) of
                   Some _ => tm
                 | None   => Var (name,ary))
          | Abs (a,ary,body) => Abs (a, ary, gen body)
          | f $ x            => (gen f) $ (gen x)
          | _                => tm
  in  gen  end;


(*Split a list around its i-th element (counting from 1):
    [x1,...,xn]  ----->   ([x1,...,x(i-1)],  xi,  [x(i+1),...,xn])
  Raises list with "triple" when nth_tail yields no i-th element.*)
fun triple (i, xs: 'a list) : 'a list * 'a * 'a list =
  let val n = i-1
  in  case nth_tail (n, xs) of
          []       => raise list with "triple"
        | x::rest  => (front (n, xs), x, rest)
  end;


(*Arity used for goals: "trivial" checks its term with check_arity_term Agoal.*)
val Agoal = Ground "goal";


(*The abstract types of rules and theories.
  A rule  prems/concl  lives in signature "sign" and carries the "stamp" of
  its theory (compose refuses rules whose stamps differ).  "tpairs" are
  pending constraints: pairs of terms still to be unified.  "maxidx" is
  computed with maxidx_of_term(s) to bound the indexes in the rule.  "depol"
  maps parameter names to their lists of dependencies.
  A theory pairs a signature and stamp with a symbol table of named rules.*)
abstype rule = Rule of
        {sign: signat,  stamp: int,  name: string,
         tpairs: (term*term) list,  maxidx: int,
         depol: (term list) xname_olist,
         prems: term list,  concl: term}
and theory = Theory of
        {sign: signat,  stamp: int,  rules: rule symbol_table}
with
  (*Raised by rule operations with (description, an integer detail such as a
    premise number or count, the rules involved).*)
  exception rule : string * int * rule list;

  (*Expose the record inside a rule.*)
  fun rep_rule (Rule x) = x;

  (*Replace the premises and conclusion of rl by new_prems and new_concl,
    provided each new term is alpha-convertible to the corresponding old one
    after unfolding the definitions named in "ids" and normalizing.  The
    resulting rule is therefore equivalent to the old one with respect to
    definitions.
    Tried writing a clever alphaconv test that expanded definitions lazily;
    it was too complex.  The argument "ids" avoids expanding all definitions.
    Raises rule when the number of premises differs or a pair mismatches.*)
  fun by_definition (ids: string list,
                     new_prems: term list,
                     new_concl: term,
                     rl: rule)  :  rule =
    let val Rule{sign, name, stamp, tpairs, maxidx, depol, prems, concl} = rl;
        val defs = defs_of_sign sign ids;
        fun subst tm = norm_term null_env (unfold_term (defs, tm));
        fun aconvs ([], []) = true
          | aconvs (t::ts, u::us) =
              alphaconv (subst t, subst u) andalso aconvs (ts, us)
          | aconvs _ = raise rule with
                ("by_definition: number of prems", length new_prems, [rl])
    in  if  aconvs (new_concl::new_prems, concl::prems)
        then  Rule{sign = sign, stamp = stamp, maxidx = maxidx,
                   name = name, tpairs = tpairs,
                   prems = new_prems, concl = new_concl, depol = depol}
        else raise rule with ("by_definition: mismatch", length new_prems, [rl])
    end;


  (*The rule P/P in a theory (NO DEPENDENCIES).  The normalized term is
    checked against the arity Agoal with check_arity_term.*)
  fun trivial (Theory{sign,stamp,...}, tm:term) : rule =
      let val concl = check_arity_term Agoal (norm_term null_env tm)
      in  Rule{sign = sign, name = "trivial", stamp = stamp, tpairs = [],
               maxidx = maxidx_of_term concl, depol = null_olist,
               prems = [concl], concl = concl}
      end;

  (*Make the trivial rule P/P for premise pno of rl, for subgoaling.
    nth_elem receives pno-1, so pno presumably counts from 1.  Keeps rl's
    tpairs and maxidx but records no dependencies.  Raises rule when the
    premise lookup raises list.*)
  fun triv_of_prem (rl,pno) =
      let val Rule{sign, stamp, tpairs, maxidx, prems, ...} = rl;
          val prem = nth_elem (pno-1, prems)
      in  Rule{sign = sign, stamp = stamp, tpairs = tpairs, maxidx = maxidx,
               name = "trivial", depol = null_olist,
               prems = [prem], concl = prem}
      end
      handle list => raise rule with ("triv_of_prem", pno, [rl]);


  (*Increment variables and parameters of rule by "inc", raising maxidx by
    the same amount and appending a prime to the name.  inc=0 returns rl.*)
  fun standardize 0 rl = rl
    | standardize inc rl =
        let val Rule{sign,name,stamp,tpairs,maxidx,depol,prems,concl} = rl
        in  Rule{sign=sign, maxidx= maxidx+inc, stamp=stamp,
                 depol= incr_indexes_olist (map (incr_indexes inc), inc) depol,
                 name= name ^ "'",
                 tpairs= map (pairself (incr_indexes inc)) tpairs,
                 prems= map (incr_indexes inc) prems,
                 concl= incr_indexes inc concl}
        end;


  (*Replace free parameters by uvariables (see gen_term): params with an
    entry in depol are kept.  The name is cleared.*)
  fun generalize rl =
      let val Rule{sign,name,stamp,tpairs,maxidx,depol,prems,concl} = rl
      in  Rule{sign=sign, maxidx= maxidx, stamp=stamp,
               depol= xmap (map (gen_term depol)) depol,
               name= "",
               tpairs= map (pairself (gen_term depol)) tpairs,
               prems= map (gen_term depol) prems,
               concl= gen_term depol concl}
      end;


  (*Remove all constraints on rule
    by unification followed by trivial substitution.
    Yields one rule per environment from smash_unifiers, each with tpairs=[]
    and its premises, conclusion and dependencies normalized.*)
  fun unify_constraints rl : rule sequence =
      let val Rule{sign,name,stamp,tpairs,maxidx,depol,prems,concl} = rl;
          (*env_maxidx/env_depol come from the unifier's environment*)
          fun newrl (env as Envir{maxidx=env_maxidx, depol=env_depol, ...}) =
                Rule{sign=sign, name=name, stamp=stamp, maxidx= env_maxidx,
                     depol= norm_depolist env env_depol, tpairs= [],
                     prems= map (norm_term env) prems,
                     concl= norm_term env concl};
          val env = Envir{maxidx=maxidx, asol= null_olist, depol=depol}
      in  maps newrl (smash_unifiers(sign, env, tpairs))  end;


  (*Collapse identical premises in rule (tdistinct1).
    Delete dependencies of parameters that no longer occur, and set maxidx
    as small as possible for the terms that remain.*)
  fun merge_premises rl =
      let val Rule{sign,name,stamp,tpairs,maxidx,depol,prems,concl} = rl;
          val tms = (map fst tpairs) @ (map snd tpairs) @ (concl::prems);
          val maxi = maxidx_of_terms tms;
          val usedol = itlist_right add_term_params (tms, null_olist)
      in  Rule{sign=sign, maxidx= maxi, stamp=stamp, name= name,
               tpairs= tpairs, concl= concl,
               depol=filter_olist (usedol,depol),
               prems= tdistinct1 ([], prems)}
      end;


  (*Compose two rules: unify the conclusion Q of rla with premise pno (QI)
    of rlb, replacing that premise by the premises of rla.  Unification but
    no standardizing of rla.  The resolvents are combined with rl_sequence
    by mapp_sequence (presumably resolvents first, then rl_sequence).
    Raises rule if pno is out of range or the rules' theories differ.
    If tpairs=[], could check that new envir is a unifier.*)
  fun compose_to_sequence (rlb, pno, rla, rl_sequence) : rule sequence =
    let val Rule{sign, stamp=stb, tpairs=tpb, maxidx=maxb, depol=depb,
                 prems=QS, concl=R,...} = rlb
        and Rule{stamp=sta, tpairs=tpa, maxidx=maxa, depol=depa,
                 prems=PS, concl=Q,...} = rla;
        val (QS1,QI,QS2) = triple (pno,QS)
              handle list => raise rule with("compose: number", pno, [rlb,rla]);
        fun newrl (env as Envir{maxidx=env_maxidx, depol=env_depol, asol},
                   tpairs) =
              case alist_of_olist asol of  (*avoid wasted normalizations*)
                [] => Rule{sign=sign, name="", stamp=stb, maxidx= env_maxidx,
                           depol= norm_depolist env env_depol, tpairs= tpairs,
                           prems= QS1 @ PS @ QS2, concl= R}
              | ((_,idx),_) :: _ =>
                  (*NOTE(review): relies on the first assignment in asol
                    having the smallest index -- confirm olist ordering*)
                  if idx>maxb then (*no assignments in rlb*)
                    Rule{sign=sign, name="", stamp=stb, maxidx= env_maxidx,
                         depol= norm_depolist env env_depol,
                         tpairs= normpairs env tpairs,
                         prems= QS1 @ (map (norm_term env) PS) @ QS2,
                         concl= R}
                  else (*normalize the new rule fully*)
                    Rule{sign=sign, name="", stamp=stb, maxidx= env_maxidx,
                         depol= norm_depolist env env_depol,
                         tpairs= normpairs env tpairs,
                         prems= map (norm_term env) (QS1 @ PS @ QS2),
                         concl= norm_term env R};
        (*rlb's dependencies are passed first, then rla's*)
        val env = Envir{maxidx=max[maxa,maxb], asol= null_olist,
                        depol= xmerge_olists union_aterms (depb,depa)}
    in  if sta<>stb then raise rule with ("compose: theories",0, [rlb,rla])
        else mapp_sequence newrl
                 (unifiers(sign, env, (Q,QI)::tpa@tpb)) rl_sequence
    end;


  (*Constrain rule by additional pairs of terms, as additional equalities.
    The next unification takes these into account.
    Simple instantiations like ?A->Nat are a special case;
    unification is needed to prevent circular instantiations.
    Each pair must have equal arities, otherwise raises rule.
    WARNING: cannot be used to rename variables!
    Flex-flex constraints have little effect.*)
  fun constrain (new_tpairs: (term*term) list) rl : rule =
    let val Rule{sign, name, stamp, tpairs, maxidx,depol,prems,concl} = rl
    in  if forall (fn (t,u)=> arity_of t = arity_of u) new_tpairs
        then  Rule{sign = sign, stamp = stamp,
                   maxidx = max [maxidx, maxidx_of_tpairs new_tpairs],
                   name = "", tpairs = new_tpairs@tpairs,
                   prems = prems, concl = concl, depol = depol}
        else raise rule with ("constrain", length new_tpairs, [rl])
    end;


  (*Look up a rule by name in a theory; raises rule if it is absent.*)
  fun get_rule (Theory{rules,...}) rname =
      case slookup(rules,rname) of
          Some rl => rl
        | None => raise rule with ("No such rule: "^rname,0,[]);


  (* Reading rules : each has name, dependencies, premises, conclusion.
    Every call takes a fresh stamp from the counter "next", shared by the
    theory and all of its rules.
    Should check that dependencies are declared for exactly the parameters
    appearing in the rule, and refer only to variables in the rule!*)
  local val next = ref 0 in
  fun new_theory sign srls : theory =
     let val Signat{lextab,...} = sign;
         fun read_rule (name, sdeps, sprems, sconcl) : string*rule =
               let val prems = map (read_theorem sign) sprems;
                   val concl = read_theorem sign sconcl
               in  (name,
                    Rule{sign= sign, name= name, stamp= !next, tpairs= [],
                         maxidx = maxidx_of_terms (concl::prems),
                         depol= olist_of_alist (map (read_deps lextab) sdeps),
                         prems=prems, concl=concl})
               end
     in  (*bump the counter BEFORE reading rules so they get the new stamp*)
         (next:= !next + 1;
          Theory {sign=sign, stamp= !next,
                  rules= symtab_of_alist (map read_rule srls)})
     end
  end (*local*);


  (*Expose the record inside a theory.*)
  fun rep_theory (Theory x) = x
  end;


(*The premises of a rule, as stored in its record.*)
fun prems_of_rule rl : term list =
  case rep_rule rl of {prems,...} => prems;



(*Unfold the definitions named in ids in every premise of a rule, then
  normalize; the conclusion is unchanged.  by_definition checks that the
  new premises are equivalent to the old ones.*)
fun unfold_def_in_prems (ids: string list) rl : rule =
  let val {sign, prems, concl,...} = rep_rule rl;
      val defs = defs_of_sign sign ids;
      val unfolded =
            map (fn tm => norm_term null_env (unfold_term (defs, tm))) prems
  in  by_definition (ids, unfolded, concl, rl)  end;


local
  (*Shared folder for fold_def_in_prems and fold_def: folds the definitions
    named in ids (fold_term, given the rule's maxidx) and normalizes.
    The defs are computed once per call.*)
  fun folder_of (ids: string list) rl : term -> term =
    let val {maxidx, sign, ...} = rep_rule rl;
        val defines' = defs_of_sign sign ids
    in  fn tm => norm_term null_env (fold_term (maxidx, defines', tm))  end
in

(*Fold definitions in all premises of a rule; the conclusion is unchanged.
  by_definition checks that the new premises are equivalent.*)
fun fold_def_in_prems (ids: string list) rl : rule =
  let val {prems, concl, ...} = rep_rule rl;
      val foldit = folder_of ids rl
  in  by_definition (ids, map foldit prems, concl, rl)  end;


(*Fold definitions in premises AND CONCLUSION of a rule.*)
fun fold_def (ids: string list) rl : rule =
  let val {prems, concl, ...} = rep_rule rl;
      val foldit = folder_of ids rl
  in  by_definition (ids, map foldit prems, foldit concl, rl)  end;

end;


(*Constrain a rule by term pairs read (read_termpair) from strings, using
  the rule's own signature.*)
fun read_constrain (stpairs: (string*string) list) rl =
    let val {sign,...} = rep_rule rl;
        val new_tpairs = map (read_termpair sign) stpairs
    in  constrain new_tpairs rl  end;
	

(*Simple composition of two rules: compose_to_sequence with nothing
  following the resolvents.*)
fun compose (rlb,pno,rla) : rule sequence =
  let val no_more = null_sequence
  in  compose_to_sequence (rlb, pno, rla, no_more)  end;


(*Lazy sequence of resolvents of premise pno of rlb with each rule in the
  list (each rla sits above: rla/rlb).  Each rla is standardized by rlb's
  maxidx+1 first; the resolvents of later rules follow those of earlier ones.*)
fun resolves (_, _, []) = null_sequence
  | resolves (rlb, pno, rla::rlsa) =
      let fun first_then_rest () =
            let val {maxidx,...} = rep_rule rlb;
                val rla' = standardize (maxidx+1) rla;
                val rest = resolves (rlb, pno, rlsa)
            in  spull (compose_to_sequence (rlb, pno, rla', rest))  end
      in  sequenceof first_then_rest  end;


(*Resolution expecting exactly one resolvent: at most two are computed,
  and anything other than a single one raises rule.*)
fun resolve (rlb,pno,rlsa) : rule =
  let val (firsts, _) = chop_sequence (2, resolves (rlb,pno,rlsa))
  in  case firsts of
          [rl] => rl
        | _    => raise rule with ("resolve", pno, rlb::rlsa)
  end;
  

(*Resolution of premise pno of rlb with rla, as an infix operator.*)
infix RESN;
fun (rlb RESN (pno,rla)) = resolve (rlb, pno, [rla]);


(*Resolution of the first premise: an infix composition operator,
  i.e. RESN with pno = 1.*)
infix RES;
fun (rlb RES rla) = rlb RESN (1, rla);


(*Iterated resolution: starting from a rule, resolve with each
  (pno, rules) entry in turn, each step producing a single resolvent.*)
local fun resolve_step (rlb, (pno,rlsa)) = resolve (rlb, pno, rlsa)
in
val resolvelist: rule * (int * rule list) list -> rule =
    itlist_left resolve_step
end;


(*Simple version of unify_constraints: at most two results are computed
  and precisely one is expected; otherwise raises rule.*)
fun smash_constraints rl : rule =
  let val (results, _) = chop_sequence (2, unify_constraints rl)
  in  case results of
          [rl'] => rl'
        | _     => raise rule with ("smash_constraints", 0, [rl])
  end;


(*Clean up a derived rule for later use: solve its constraints, merge its
  premises, then generalize its free parameters.*)
fun tidyrule rl =
  let val smashed = smash_constraints rl;
      val merged  = merge_premises smashed
  in  generalize merged  end;

  
(*Is the rule a theorem scheme, i.e. does it have no premises?*)
fun is_theorem rl : bool =
  case prems_of_rule rl of
      []   => true
    | _::_ => false;

(*Does the rule have at most n premises?*)
fun has_prems n rl : bool =
  not (length (prems_of_rule rl) > n);


(*Print the dependencies "deps" of the parameter "name" as "name>" followed
  by " t" for each dependency t and a newline; prints nothing if deps=[].*)
fun print_dep sign (_, []) : unit = ()
  | print_dep sign (name, deps) =
      let fun print_one t = (prs " ";  print_term sign t)
      in  prs (string_of_xname name ^ ">");
          seq print_one deps;
          prs "\n"
      end;


(*Print a rule: its name on its own line (unless empty), the premises each
  followed by spaces and then a rule line (both omitted when there are no
  premises), the conclusion, then the constraint pairs and dependencies.*)
fun print_rule rl : unit =
  let val {sign, tpairs, name, prems, concl, depol,...} = rep_rule rl;
      fun print_name () = if name = "" then () else prs (name ^ "\n");
      fun print_prem prem = (print_theorem sign prem;  prs "    ");
      fun print_prems [] = ()
        | print_prems ps = (seq print_prem ps;
                            prs "\n---------------------------------\n")
  in  print_name ();
      print_prems prems;
      print_theorem sign concl;
      prs "\n";
      seq (print_tpair sign) tpairs;
      seq (print_dep sign) (alist_of_olist depol)
  end;


(*Print a rule in "goal style": the conclusion, then each premise as a
  numbered subgoal ("  1. ", "  2. ", ...), then the constraint pairs.
  The name and dependencies are not printed; dependencies are a nuisance.*)
fun print_goal_rule rl : unit =
  let val {sign, tpairs, prems, concl,...} = rep_rule rl;
      fun print_subgoal (n, Q) =
            (prs ("  " ^ string_of_int n ^ ". ");
             print_theorem sign Q;
             prs "\n")
      fun number_from (_, []) = ()
        | number_from (n, Q::rest) =
            (print_subgoal (n, Q);  number_from (n+1, rest))
  in  print_theorem sign concl;
      prs "\n";
      number_from (1, prems);
      seq (print_tpair sign) tpairs
  end;


(*Print a list of rules with print_list_ln, using print_rule for each.*)
fun print_rules (rls: rule list) : unit = print_list_ln print_rule rls;


(*The trivial rule P/P in theory thy, where P is read from the string s
  in the theory's signature.*)
fun rtrivial (thy,s) =
  let val {sign,...} = rep_theory thy;
      val goal_tm = read_theorem sign s
  in  trivial (thy, goal_tm)  end;


