(* extract.ml                                            (c) R.J.Boulton 1990 *)
(*----------------------------------------------------------------------------*)


structure TRS_struct: TRS_struct_sig =
struct

	(* Function to take a term and return a nested triple:  *)
	(* (<constants>,(<free variables>,<bound variables>))   *)
	
	(* Set union and set difference are used to avoid generating repetitions in *)
	(* the lists derived. For function applications, the lists from the rator   *)
	(* and the rand can simply be joined by set union. For abstractions, the    *)
	(* bound variable is removed from the free-variable list of the body, and   *)
	(* is added to the bound-variable list of the body.                         *)
	exception get_ids_FAIL;

	(* get_ids : term -> (term list * (term list * term list))             *)
	(*                                                                     *)
	(* Returns (<constants>,(<free variables>,<bound variables>)) for t.   *)
	(* The three lists are built with union/subtract, so the lists built   *)
	(* here contain no repetitions (assuming those hol90 Lib set           *)
	(* operations behave as sets).                                         *)
	(* Raises get_ids_FAIL if t is not a constant, variable, abstraction   *)
	(* or combination.                                                     *)

	fun get_ids t =
	   if Term.is_const t then ([t],([],[]))
	   else if Term.is_var t then ([],([t],[]))
	   else if Term.is_abs t then
	     let (* Bind by field name: Bvar is the bound variable and Body is *)
	         (* the body. Recursion must be on the body; the bound variable *)
	         (* is then moved from the free list to the bound list.         *)
	         val {Bvar=bv,Body=body} = Term.dest_abs t
	         val (cl,(fvl,bvl)) = get_ids body
	     in  (cl,(subtract fvl [bv], union bvl [bv]))
	     end
	   else if Term.is_comb t then
	     let val {Rator=a,Rand=b} = Term.dest_comb t
	         val (cla,(fvla,bvla)) = get_ids a
	         val (clb,(fvlb,bvlb)) = get_ids b
	     in  (union cla clb,(union fvla fvlb,union bvla bvlb))
	     end
	   else raise get_ids_FAIL;


(* Functions to extract components from the get_ids triple *)

(* get_consts : term -> term list                                       *)
(* The constants occurring in a term (first component of get_ids).      *)

fun get_consts t =
   let val (consts,_) = get_ids t
   in  consts
   end;

(* get_freevars : term -> term list                                     *)
(* The free variables of a term (first element of get_ids' second part). *)

fun get_freevars t =
   let val (_,(frees,_)) = get_ids t
   in  frees
   end;

(* get_boundvars : term -> term list                                    *)
(* The bound variables of a term (last element of get_ids' second part). *)

fun get_boundvars t =
   let val (_,(_,bounds)) = get_ids t
   in  bounds
   end;


(* Function to obtain a list of the types which occur in a term *)

(* The lists of constants, free variables and bound variables are        *)
(* concatenated. The resulting identifiers are converted to their types, *)
(* and then any repetitions are removed.                                 *)

fun get_types t =

   (* get_types : term -> hol_type list                                *)
   (* Types of every constant, free variable and bound variable of t,  *)
   (* with repetitions removed.                                        *)

   let val (consts,(frees,bounds)) = get_ids t
       (* Constants and variables both destructure to a {Name,Ty}       *)
       (* record; try the constant destructor first, then the variable. *)
       fun type_of_id id =
          #Ty (Term.dest_const id handle _ => Term.dest_var id)
   in  TRS_sets.remove_rep (map type_of_id (consts @ frees @ bounds))
   end;


(*--------------------------------------------------------------------------*)


(* Function which, applied to a HOL type, returns true if the type is a   *)
(* type variable (such as `:'a`) or a type constructor applied to no      *)
(* arguments (such as `:bool`); otherwise false is returned.              *)

fun is_primtype typ =
   (* Type.dest_type fails on type variables, which count as primitive; *)
   (* otherwise a type is primitive exactly when it has no arguments.   *)
   (case #Args (Type.dest_type typ) of
       [] => true
     | _ :: _ => false)
   handle _ => true;

   (* : (hol_type -> bool) *)


(* Function which, applied to a HOL type, returns a list containing just *)
(* the type itself if it is "primitive", or otherwise the types from     *)
(* which it is immediately constructed (its type arguments).             *)

fun subtypes typ =

   (* subtypes : hol_type -> hol_type list *)

   if not (is_primtype typ)
   then #Args (Type.dest_type typ)
   else [typ];


(* Function to break up a type into its "primitive" types *)

(* The function uses the predicate is_primtype, defined above. If the     *)
(* type is not "primitive", a list of the component types is obtained, to *)
(* which the function is applied recursively. The resulting list of lists *)
(* is then flattened to give a list of "primitive" types, from which any  *)
(* repetitions are removed.                                                *)

fun prim_subtypes typ =

   (* prim_subtypes : hol_type -> hol_type list *)

   if is_primtype typ then [typ]
   else
      let val per_component = map prim_subtypes (subtypes typ)
      in  TRS_sets.remove_rep (flatten per_component)
      end;


(* Function to obtain a list of the "primitive" types occurring in a term *)

(* A list of the types occurring in the term is obtained. Each of these *)
(* types is converted to a list of its "primitive" types. The resulting *)
(* list of lists is then flattened, and any repetitions are removed.    *)

fun get_primtypes t =

   (* get_primtypes : term -> hol_type list *)

   let val per_type = map prim_subtypes (get_types t)
   in  TRS_sets.remove_rep (flatten per_type)
   end;


(* Function to obtain a list of the "primitive" polymorphic types in a term *)

fun get_primvartypes t =
   let val prims = get_primtypes t
   in  filter Type.is_vartype prims
   end;

   (* get_primvartypes : term -> hol_type list *)


(* Function to merge two association lists, failing if the lists are *)
(* inconsistent.                                                      *)

(* The first element of the pair at the head of l2 is looked up in l1. If the *)
(* second element of the pair obtained is equal to the second element of the  *)
(* pair at the head of l2, then the head of l2 is discarded. Otherwise, the    *)
(* merge fails. If the look-up in l1 fails, the head of l2 is retained.        *)
local

 exception merge_FAIL;

 (* merge : (''a * ''b) list -> (''a * ''b) list -> (''a * ''b) list        *)
 (* Appends to l1 those pairs of l2 whose key is unbound in l1; a pair whose *)
 (* key is bound in l1 to a different value raises merge_FAIL.               *)

 fun merge l1 [] = l1
   | merge l1 ((key,value)::rest) =
       let (* NONE when key has no binding in l1 *)
           val found = SOME (assoc key l1) handle NOT_FOUND => NONE
       in  case found of
              NONE => (key,value) :: merge l1 rest
            | SOME bound =>
                 if bound = value then merge l1 rest else raise merge_FAIL
       end;


(* Function to merge two "match" lists *)

 fun join (avtl,attl) (bvtl,bttl) =

   (* join : (term * term) list * (hol_type * hol_type) list ->             *)
   (*        (term * term) list * (hol_type * hol_type) list ->             *)
   (*        (term * term) list * (hol_type * hol_type) list                *)
   (* Merges the variable-term lists and the type-type lists separately;   *)
   (* an inconsistency in either is reported as TRS_extents.NO_MATCH.      *)

   let val vtl = merge avtl bvtl
       val ttl = merge attl bttl
   in  (vtl,ttl)
   end
   handle merge_FAIL => raise TRS_extents.NO_MATCH;


(* Function to remove a bound variable from a "match" list *)

(* Any pairs in the variable-term association list which have the *)
(* bound variable as their first element are filtered out.         *)

 fun remove_bv bv (vtl,ttl) =

   (* remove_bv : term -> (term * term) list * (hol_type * hol_type) list ->  *)
   (*                     (term * term) list * (hol_type * hol_type) list     *)
   (* The type-type list is returned unchanged.                              *)

   let fun keep (v,_) = v <> bv
   in  (filter keep vtl, ttl)
   end;


(* Function for matching two types *)

(* The first type given, p, must be the more general.                           *)

(* If p is a type variable (i.e. one containing no constructors) then it can    *)
(* match any type. A single-item association list is constructed from the two  *)
(* types in such a case.                                                        *)

(* If p is not a type variable, it is split up into a constructor and a list of *)
(* simpler types. An attempt is made to split t, also. If this fails, then no  *)
(* match can be made. If the constructors obtained from p and t are different  *)
(* then the match must fail. The lists of simpler types obtained from          *)
(* decomposing p and t are converted to a list of pairs, the match failing if  *)
(* the original lists were not of the same length. The function is then        *)
(* applied recursively to each pair of the new list, and the results are       *)
(* merged. If merging fails, the whole match fails.                             *)

 fun match_type p t =

   (* match_type : hol_type -> hol_type -> (hol_type * hol_type) list      *)
   (* p is the (more general) pattern type; t is the type it must match.  *)
   (* Raises TRS_extents.NO_MATCH when no consistent matching exists.     *)

   if not (Type.is_vartype p) then
     let val {Tyop=pc,Args=ptypl} = Type.dest_type p
         val {Tyop=tc,Args=ttypl} =
               Type.dest_type t handle _ => raise TRS_extents.NO_MATCH
     in  if pc <> tc then raise TRS_extents.NO_MATCH
         else
           (let val arg_pairs = combine (ptypl,ttypl)
                val arg_matches = map (fn (x,y) => match_type x y) arg_pairs
            in  itlist merge arg_matches []
            end
            handle merge_FAIL => raise TRS_extents.NO_MATCH)
     end
   else [(p,t)];


(* Function for matching two terms *)

(* The first term given, p, must be the more general.                          *)

(* The function consists of four cases:                                        *)

(* p is a constant. If t is not a constant, the match fails. If the names of  *)
(* p and t are different, the match fails. Constants cannot be wildcards, so  *)
(* only the types need adding to the "match" list. One might think the match  *)
(* should fail if the types are different, but this is not the case.          *)
(* Consider the "=" function, for instance. The types must match, however.    *)

(* p is a variable. A variable can match any term, provided its type can be   *)
(* matched to that of the term.                                                *)

(* p is an abstraction. An abstraction can only match another abstraction.    *)
(* p and t are decomposed into their bound variables and bodies. The bound    *)
(* variables are matched to obtain the type matchings. The bodies are also    *)
(* matched. The resultant matchings are then merged, and the bound variable   *)
(* is then removed from the variable-term list to allow for renaming of       *)
(* bound variables. Note that the merge has to be done before the bound       *)
(* variable is removed to ensure the bound variables correspond in the body.  *)

(* p is a combination. A combination can only match another combination.     *)
(* p and t are decomposed into their rators and rands. The two rators are     *)
(* matched against each other. The two rands are matched. Then the resulting  *)
(* "match" lists are merged.                                                   *)

 exception match_term_FAIL;
 fun match_term p t =

   (* match_term : term -> term ->                                          *)
   (*              (term * term) list * (hol_type * hol_type) list           *)
   (* Raises TRS_extents.NO_MATCH if t is not an instance of p, and         *)
   (* match_term_FAIL if p is none of const/var/abs/comb.                   *)

   let fun no_match () = raise TRS_extents.NO_MATCH
       (* the type matching between the whole of p and the whole of t *)
       fun whole_type_match () = match_type (Term.type_of p) (Term.type_of t)
   in
     if Term.is_const p then
        if not (Term.is_const t) then no_match ()
        else if #Name (Term.dest_const p) = #Name (Term.dest_const t)
        then ([], whole_type_match ())
        else no_match ()
     else if Term.is_var p then
        ([(p,t)], whole_type_match ())
     else if Term.is_abs p then
        let val {Bvar=pbv,Body=pbody} = Term.dest_abs p
            val {Bvar=tbv,Body=tbody} =
                  Term.dest_abs t handle _ => no_match ()
            (* bound variables first, then bodies, as in the merge below *)
            val bv_match = match_term pbv tbv
            val body_match = match_term pbody tbody
        in  remove_bv pbv (join bv_match body_match)
        end
     else if Term.is_comb p then
        let val {Rator=prator,Rand=prand} = Term.dest_comb p
            val {Rator=trator,Rand=trand} =
                  Term.dest_comb t handle _ => no_match ()
            val rator_match = match_term prator trator
            val rand_match = match_term prand trand
        in  join rator_match rand_match
        end
     else raise match_term_FAIL
   end;


(* Function to match a term pattern inside a term *)

(* The function applies match_term to the pattern and the term. If this fails, *)
(* the function is called recursively on any possible sub-terms of the term    *)
(* (rator, then rand, then abstraction body). If all these attempts to match   *)
(* fail, the whole evaluation fails with TRS_extents.NO_MATCH.                  *)

 fun match_internal_term p t =

   (* match_internal_term : term -> term ->                                 *)
   (*                       (term * term) list * (hol_type * hol_type) list  *)

   let fun body_of tm = #Body (Term.dest_abs tm)
       (* search inside the sub-term selected by sel *)
       fun search_in sel = match_internal_term p (sel t)
   in  match_term p t
       handle _ => search_in Term.rator
       handle _ => search_in Term.rand
       handle _ => search_in body_of
       handle _ => raise TRS_extents.NO_MATCH
   end;

in
(*----------------------------------------------------------------------------*)


(* Abstract datatype for wildcard variables to be used in pattern matching *)

exception wildvar_FAIL;
datatype wildvar = WILDVAR of Term.term;

   (* show_wildvar : wildvar -> Term.term                                *)
   (* Recover the variable that a wildvar wraps.                         *)

   fun show_wildvar w = case w of WILDVAR v => v;

   (* make_wildvar : Term.term -> wildvar                                *)
   (* Wrap a term as a wildvar; raises wildvar_FAIL unless the term is   *)
   (* a variable.                                                        *)

   fun make_wildvar t =
       if not (Term.is_var t) then raise wildvar_FAIL
       else WILDVAR t;

(* wildvarlist : Term.term list -> wildvar list                         *)
(* Wrap every term as a wildvar; raises wildvar_FAIL if any is not a    *)
(* variable.                                                            *)

fun wildvarlist terms = map make_wildvar terms;


(*----------------------------------------------------------------------------*)


(* Abstract datatype for wildcard types to be used in pattern matching *)

exception wildtype_FAIL;
datatype wildtype = WILDTYPE of Type.hol_type;

   (* show_wildtype : wildtype -> Type.hol_type                          *)
   (* Recover the type that a wildtype wraps.                            *)

   fun show_wildtype w = case w of WILDTYPE ty => ty;

   (* make_wildtype : Type.hol_type -> wildtype                          *)
   (* Wrap a type as a wildtype; raises wildtype_FAIL unless the type is *)
   (* a type variable (a "primitive" polymorphic type).                  *)

   fun make_wildtype t =
       if not (Type.is_vartype t) then raise wildtype_FAIL
       else WILDTYPE t;

(* wildtypelist : Type.hol_type list -> wildtype list                   *)
(* Wrap every type as a wildtype; raises wildtype_FAIL if any is not a  *)
(* type variable.                                                       *)

fun wildtypelist types = map make_wildtype types;


(*----------------------------------------------------------------------------*)


(* Abstract datatype for patterns used to match terms *)

exception termpattern_FAIL;
datatype termpattern = TERMPATTERN of (Term.term * wildvar list * wildtype list);

   (* show_termpattern : termpattern ->                                  *)
   (*                    (Term.term * wildvar list * wildtype list)      *)
   (* Expose the template term and its two wildcard lists.              *)

   fun show_termpattern p = case p of TERMPATTERN contents => contents;

   (* make_termpattern : (Term.term * wildvar list * wildtype list) ->   *)
   (*                    termpattern                                     *)
   (* Build a termpattern from a template term tm and its wildcards.     *)
   (* It is accepted only if neither wildcard list has repetitions and   *)
   (* both TRS_sets.is_subset checks below hold (intended: the wildvars  *)
   (* are free variables of tm and the wildtypes are "primitive"         *)
   (* polymorphic types of tm -- argument order of is_subset not visible *)
   (* here). Otherwise termpattern_FAIL is raised.                       *)

   fun make_termpattern (tm,wvl,wtl) =
       let val vars = map show_wildvar wvl
           val tys = map show_wildtype wtl
           (* short-circuits in the same order as the checks are listed *)
           val acceptable =
                 TRS_sets.no_rep vars
                 andalso TRS_sets.no_rep tys
                 andalso TRS_sets.is_subset (get_freevars tm) vars
                 andalso TRS_sets.is_subset (get_primvartypes tm) tys
       in  if acceptable then TERMPATTERN (tm,wvl,wtl)
           else raise termpattern_FAIL
       end;


(* Function to convert a termpattern into its representing type, and the *)
(* wildvars and wildtypes within that to their representing types.       *)
(* So the function makes all of a termpattern visible.                   *)

fun show_full_termpattern p =

   (* show_full_termpattern : termpattern ->                              *)
   (*                         (Term.term * Term.term list * hol_type list) *)

   let val (tm,wvl,wtl) = show_termpattern p
       val vars = map show_wildvar wvl
       val tys = map show_wildtype wtl
   in  (tm,vars,tys)
   end;


(* Function to make a termpattern from a term, a list of terms, and a list of *)
(* types. The term represents the pattern. The list of terms represents the   *)
(* variables which are to be taken as wildcards, and the list of types        *)
(* represents the "primitive" polymorphic types which are to be taken as      *)
(* wildcards.                                                                  *)

fun make_full_termpattern (tm,terml,typel) =

   (* make_full_termpattern : (Term.term * Term.term list * hol_type list) -> *)
   (*                         termpattern                                    *)

   let val wvl = wildvarlist terml
       val wtl = wildtypelist typel
   in  make_termpattern (tm,wvl,wtl)
   end;


(* Function to make a termpattern out of a term by using the free variables in *)
(* the term as wildvars and the "primitive" polymorphic types as wildtypes.    *)

fun autotermpattern t =

   (* autotermpattern : Term.term -> termpattern *)

   let val frees = get_freevars t
       val polys = get_primvartypes t
   in  make_full_termpattern (t,frees,polys)
   end;


(*----------------------------------------------------------------------------*)


(* Abstract datatype for the result of matching a termpattern against a term *)

datatype matching = MATCHING of ((wildvar * Term.term) list * (wildtype * Type.hol_type) list);

   (* show_matching : matching ->                                        *)
   (*   ((wildvar * Term.term) list * (wildtype * hol_type) list)        *)
   (* Expose the wildvar and wildtype binding lists of a matching.      *)

   fun show_matching m = case m of MATCHING bindings => bindings;

   (* null_matching : matching                                           *)
   (* The matching that binds no wildvars and no wildtypes.             *)

   val null_matching = MATCHING ([],[]);
	
	   (* Function to form a matching from a termpattern and a term *)
	
	   fun make_matching p t =
	
	       (* make_matching : termpattern -> Term.term -> matching          *)
	       (* Matches the template of termpattern p against t and returns   *)
	       (* the bindings of p's wildvars and wildtypes. Any exception     *)
	       (* raised by match_term propagates; TRS_extents.NO_MATCH is      *)
	       (* also raised if a non-wildcard variable or type of the         *)
	       (* template is not matched to itself (see "build" below).        *)
	
	           (* Extract low-level components of termpattern *)
	
	            let val (tm,varl,typl) = show_full_termpattern p
	
	               (* Use "match_term" to attempt a matching of the template tm *)
	               (* against the term t. If this fails, "make_matching" fails.  *)
	               (* If it succeeds, the pair of lists returned by "match_term" *)
	               (* (variable-term pairs and type-type pairs) is bound to     *)
	               (* (vpl,tpl) for further analysis/processing.                 *)
	
	                val (vpl,tpl) = match_term tm t
	
	               (* The variable-term list returned by "match_term" is a list *)
	               (* of pairs such that the first element of the pair is a     *)
	               (* variable in tm, and the second element of the pair is the *)
	               (* term in t that the variable has been matched to.          *)
	
	               (* Bound variables in tm do not appear in the result of      *)
	               (* "match_term". Some of the variables which do appear may   *)
	               (* not have been specified as wildvars. The match must fail  *)
	               (* if such a variable does not (when its type has been       *)
	               (* instantiated) match itself in the list returned by        *)
	               (* "match_term". The type-type list returned by "match_term" *)
	               (* is used to perform the instantiation.                     *)
	
	               (* Types are dealt with similarly, except that there is no   *)
	               (* equivalent action to instantiation.                       *)
	
	               (* The matching we are trying to construct should look like  *)
	               (* the result of "match_term" except that the variables and  *)
	               (* types from tm should be converted to wildcards, and only  *)
	               (* those of them that appear as wildcards in the termpattern *)
	               (* should be included.                                       *)
	
	               (* Now we know what we are trying to achieve, let us define  *)
	               (* some functions to help us.                                *)
	
	               (* f is used to convert the term or type which is representing *)
	               (* a wildcard into the appropriate wildcard type.              *)
	
	               fun f w (a,b) = ((w a),b)
	
	                  (* : (('a -> 'b) -> ('a * 'c) -> ('b * 'c)) *)
	
	               (* "instant_type" instantiates the type of a variable using a *)
	               (* (type * type) list in which the first element of each pair *)
	               (* is a "primitive" type. The embedded function "change_type" *)
	               (* does the real work. "instant_type" splits the variable into *)
	               (* its name and type, applies "change_type" to the type, and  *)
	               (* then reconstructs the variable using the new type.         *)
	
	               fun instant_type ttl v =
	
	                  (* : ((hol_type * hol_type) list -> Term.term -> Term.term) *)
	
	                  (* "change_type" instantiates a type. If the type is       *)
	                  (* "primitive", it is looked up in the instantiation list. *)
	                  (* If found, the corresponding instance is returned; if    *)
	                  (* not, the type itself is returned. If the type is not    *)
	                  (* "primitive", it is decomposed into a constructor and a  *)
	                  (* list of simpler types. Each of the latter is then       *)
	                  (* instantiated, and the type is reconstructed.            *)
	
	                 ( let fun change_type ttl typ =
	
	                     (* : ((hol_type * hol_type) list -> hol_type -> hol_type) *)
	
	                     if (is_primtype typ)
	                     then (assoc typ ttl) handle _ => typ
	                     else ( let val {Tyop=s,Args=l} = Type.dest_type typ
	                           in  Type.mk_type {Tyop=s,Args=(map (change_type ttl) l)}
				   end)
	                  in ( let val {Name=s,Ty=t} = Term.dest_var v
	                      in  Term.mk_var {Name=s,Ty=(change_type ttl t)}
			      end)
			  end
	                 )
	
	                   (* "build" keeps each pair of xxl whose first element is *)
	                   (* in xl. A pair whose first element is not in xl is     *)
	                   (* dropped if lf applied to that first element equals    *)
	                   (* the pair's second element; otherwise the match being  *)
	                   (* performed fails with TRS_extents.NO_MATCH.            *)
	
	                   (* "build" is used to build a matching from a "match"    *)
	                   (* list and a wildcard list. Any variable or type in the *)
	                   (* "match" list but not in the wildcard list must match  *)
	                   (* itself (allowing for type instantiation, hence the    *)
	                   (* need for lf), and will not be included in the result. *)
	
	               in ( let fun build lf xl xxl =
	
	                      (* : (('a -> 'b) -> 'a list -> ('a * 'b) list ->       *)
	                      (*                                  ('a * 'b) list)     *)
	
	                      if (null xxl)
	                      then []
	                      else if (mem ((fst o hd) xxl) xl)
	                           then (hd xxl)::(build lf xl (tl xxl))
	                           else if ((lf o fst o hd) xxl = (snd o hd) xxl)
	                                then (build lf xl (tl xxl))
	                                else raise TRS_extents.NO_MATCH
	
	                   (* Note: assumes all variables which could be wildvars   *)
	                   (*       appear in the matching returned by "match_term" *)
	
	                   in MATCHING (
	                        (map (f make_wildvar)
	                                (build (instant_type tpl) varl vpl)),
	                        (map (f make_wildtype) (build ( fn x => x) typl tpl)))
			   end
	                  )
			end;
	
	   (* Function to combine two (consistent) matchings into a single matching *)
	
	   (* Split the two matchings into wildvar and wildtype "match" lists. Merge *)
	   (* the two resulting wildvar lists and the two resulting wildtype lists.  *)
	   (* If either merge fails, the match fails.                                *)
	
	   fun join_matchings (MATCHING (mwvl,mwtl)) (MATCHING (nwvl,nwtl)) =

	       (* join_matchings : matching -> matching -> matching              *)
	       (* Raises TRS_extents.NO_MATCH if the two matchings disagree.     *)

	       let val wvl = merge mwvl nwvl
	           val wtl = merge mwtl nwtl
	       in  MATCHING (wvl,wtl)
	       end
	       handle merge_FAIL => raise TRS_extents.NO_MATCH;
	
	
	(* Function to convert a matching into its representing type, and the *)
	(* wildvars and wildtypes within that to their representing types.    *)
	(* So the function makes all of a matching visible.                   *)
	
	fun show_full_matching m =

	   (* show_full_matching : matching ->                                  *)
	   (*   ((Term.term * Term.term) list * (hol_type * hol_type) list)     *)

	   let val (var_binds,type_binds) = show_matching m
	       val plain_vars = map (fn (w,tm) => (show_wildvar w, tm)) var_binds
	       val plain_types = map (fn (w,ty) => (show_wildtype w, ty)) type_binds
	   in  (plain_vars,plain_types)
	   end;
	
	
	(* Function to look up a wildvar in a matching, and return the term to *)
	(* which it is bound.                                                   *)
	
	exception UNKNOWN_WILDVAR;
	fun match_of_var m wv =

	   (* match_of_var : matching -> wildvar -> Term.term                   *)
	   (* Raises UNKNOWN_WILDVAR if wv is not bound in m.                   *)

	   let val (var_binds,_) = show_matching m
	   in  assoc wv var_binds
	   end
	   handle NOT_FOUND => raise UNKNOWN_WILDVAR;
	
	
	(* Function to look up a wildtype in a matching, and return the type to *)
	(* which it is bound.                                                    *)
	exception UNKNOWN_WILDTYPE;
	fun match_of_type m wt =

	   (* match_of_type : matching -> wildtype -> hol_type                  *)
	   (* Raises UNKNOWN_WILDTYPE if wt is not bound in m.                  *)

	   let val (_,type_binds) = show_matching m
	   in  assoc wt type_binds
	   end
	   handle NOT_FOUND => raise UNKNOWN_WILDTYPE;
	
	
	(*----------------------------------------------------------------------------*)
	
	
	(* Datatype for lazy evaluation of alternate matchings *)
	
	(* Nomatch means there is no way to match.                                    *)
	(* Match means there is at least one way to match, and specifies the matching *)
	(* (which may be null). The second element of the pair is a function to      *)
	(* generate any other matchings if they exist.                                *)
	
	(* Nomatch          : no matching exists.                                *)
	(* Match (m, rest)  : m is one matching; rest () yields the remaining    *)
	(*                    matchings, computed only when it is called.        *)
	datatype result_of_match = Nomatch
	                        | Match of matching * (unit -> result_of_match);
	
	
	(* Abbreviation for a result_of_match which is a match with no bindings *)
	
	val Match_null =
	   (* exactly one matching, the one with no bindings *)
	   let fun no_more () = Nomatch
	   in  Match (null_matching, no_more)
	   end;
	
	
	(* Function to append two lazy lists (values of type result_of_match) *)
	
	(* "approms" appends two result_of_match values, which are essentially just *)
	(* lazy lists of matchings. The result must be kept as lazy as possible.    *)
	(* This function is also used to OR two result_of_match values, since this  *)
	(* operation corresponds exactly to appending them.                          *)
	
	(* The arguments to "approms" are actually functions from unit to a         *)
	(* result_of_match, so that as little evaluation as necessary is done.       *)
	
	(* The function is defined in an analogous way to "append" on lists.         *)
	
	fun approms rom1fn rom2fn () =

	   (* approms : (unit -> result_of_match) -> (unit -> result_of_match) ->  *)
	   (*           (unit -> result_of_match)                                  *)
	   (* Forces the first list only as far as its head. When it is empty the *)
	   (* answer is the second list; otherwise its head is kept and appending *)
	   (* its tail to the second list is deferred.                            *)

	   (case rom1fn () of
	       Match (m,rest) => Match (m, approms rest rom2fn)
	     | Nomatch => rom2fn ());


	(* Function to convert a Boolean value to a result_of_match *)

	fun bool_to_rom b =

	   (* bool_to_rom : bool -> result_of_match                             *)
	   (* true gives the single empty matching; false gives no matching.    *)

	   (case b of
	       true => Match_null
	     | false => Nomatch);


	(* Function to convert a result_of_match to a Boolean value *)

	(* Note that information may be thrown away in this process. *)

	fun rom_to_bool rom =
	   (* true iff at least one matching exists; the matchings are discarded *)
	   (case rom of
	       Match _ => true
	     | Nomatch => false);

	   (* rom_to_bool : result_of_match -> bool *)
	

	(* Abbreviation for the type representing side-conditions (the type *)
	(* side_condition is declared after this structure).                *)

	(* When applied to a matching, a side-condition performs tests on the        *)
	(* bindings in the matching, and returns a "lazy list" of any successful new *)
	(* bindings.                                                                  *)


end;

end;

(* A side-condition takes a matching and returns a lazy list            *)
(* (TRS_struct.result_of_match) of matchings.                           *)
type side_condition = TRS_struct.matching -> TRS_struct.result_of_match;

