(*  Title: 	unify
    Author: 	Lawrence C Paulson, Cambridge University Computer Laboratory
    Copyright   Cambridge University 1986
*)

(*Higher-Order Unification

matchcopy could be written in continuation style, instead of stream grinding
*)


(*Unification options.
  Each option is a reference cell, so it can be reset between runs.  The same
  three cells are reachable through the record unif_options and directly by
  name.*)

val trace_bound = ref 10	(*tracing starts above this depth*)
and search_bound = ref 20	(*unification quits above this depth*)
and trace_simp = ref false;	(*print dpairs before calling SIMPL*)

val unif_options =
   {trace_bound = trace_bound,
    search_bound = search_bound,
    trace_simp = trace_simp};

(*Binder of a disagreement pair: name and arity of each enclosing bound
  variable.  Code in this file conses the innermost variable onto the front
  and indexes the list by Bound number (see eta_norm and fastarity).*)
type binderlist = (string*arity) list;

(*Disagreement pair: two terms to be unified beneath a common binder*)
type dpair = binderlist * term * term;

(*Print the disagreement pairs, one per line, after normalising both terms
  in env.  The second term is printed first because it may be rigid or
  flexible, whereas the other term is always flexible.
  A non-empty binder is indicated only by its length, shown as <n>.*)
fun print_dpairs sign (env,dpairs)   =
  let fun show tm = print_term sign (norm_term env tm);
      fun show_binder rbinder =
	    (case rbinder of
		 [] => ()
	       | _  => (prs"<";  pri (length rbinder);  prs"> "));
      fun show_dpair (rbinder,t,u) =
	    (show_binder rbinder;
	     show u;  prs ("  =?=  ");  show t;
	     prs"\n")
  in  seq show_dpair dpairs  end;



(*Put a term into head normal form for unification.
  Operands need not be in normal form.  Does eta-expansions on the head,
  which involves renumbering (thus copying) the args.  To avoid this
  inefficiency, avoid partial application:  if an atom is applied to
  any arguments at all, apply it to its full number of arguments.
  For
    rbinder = [(x1,ary1),...,(xm,arym)]		(user's var names preserved!)
    args  =   [arg1,...,argn]
  the value of
      (xm,...,x1)(head(arg1,...,argn))  remains invariant.

  Implementation note: the local function raises exception same when it has
  nothing to change, so that an unchanged term is returned as is rather than
  rebuilt.*)

fun head_norm (env,tm) : term =
  let (*contract a beta-redex, then continue normalising the result*)
      fun beta (body, rand) =
	    head_norm (env, subst_bounds(~1, [rand], body));
      fun reduce t = case t of
	    Var (uname,_) =>
	      (case elookup (env,uname) of
		   None   => raise same
		 | Some u => head_norm (env, u))
	  | Abs(name,ary,body) => Abs(name, ary, reduce body)
	  | Abs(_,_,body) $ rand => beta (body, rand)
	  | rator $ rand =>
	      (case reduce rator of
		   Abs(_,_,body) => beta (body, rand)
		 | nrator => nrator $ rand)
	  | _ => raise same
  in  reduce tm  handle same=> tm  end;


(*Arity of a term, computed without checking that combinations are
  consistent:  for an application only the operator is examined.
  rbinder holds the arities of the bound variables.
  Raises term_error if an operator has ground arity, or if a Bound index
  is not covered by rbinder.*)
fun fastarity (rbinder, tm as (rator$_)) : arity =
      (case fastarity (rbinder, rator) of
	    _ --> ary => ary
	  | Ground _ => raise term_error with ("fastarity: Comb", [tm]))
  | fastarity (_, Const (_,ary)) = ary
  | fastarity (_, Param (_,ary)) = ary
  | fastarity (_, Var (_,ary)) = ary
  | fastarity (rbinder, tm as (Bound bno)) =
      (snd (nth_elem (bno,rbinder))
       handle list=> raise term_error with ("fastarity: Bound", [tm]))
  | fastarity (rbinder, Abs (_,ary,body)) =
      ary --> fastarity (("",ary) :: rbinder, body);


(* Eta normal form *)

(*Eta-expand tm as directed by its arity:  one nameless abstraction is
  wrapped around the term for each argument arity, the term being applied
  to the new bound variable.  A term of ground arity is returned unchanged.*)
fun etif (ary, tm) = case ary of
    Ground _ => tm
  | dom --> rng =>
	Abs("", dom, etif(rng, incr_boundvars 1 tm $ Bound 0));

(*Eta-normal form:  descend through the leading abstractions, recording each
  bound variable in rbinder, then eta-expand the body to its full arity.*)
fun eta_norm (rbinder, tm) : term = case tm of
    Abs(name,ary,body) =>
	Abs(name, ary, eta_norm ((name,ary) :: rbinder, body))
  | _ => etif(fastarity(rbinder,tm), tm);

(*Signals failure of a unification attempt.  Raised by the simplification
  code (fixedpoint, SIMPL0) and by var_head_of;  handled in matchcopy and
  unifiers, where it causes the current alternative to be abandoned.*)
exception unify : unit;		(*non-unifiable dpair found, not an error*)



(*Extends an rbinder with the two terms of a new disagreement pair, stripping
  matching outer abstractions from both terms.
  Checks that binders have same length, since terms should be eta-normal;
  raises term_error if only one of the terms is an abstraction.
  Does NOT compare corresponding arities:  the arity of the first term is used.
  Uses nonempty variable name (if present) to preserve user's naming,
  preferring the name from the first term.*)
fun new_dpair (rbinder, t1, t2) =
  case (t1, t2) of
      (Abs(name1,ary1,body1), Abs(name2,_,body2)) =>
	let val name = if name1="" then name2 else name1
	in  new_dpair((name,ary1) :: rbinder, body1, body2)  end
    | (Abs _, _) => raise term_error with ("new_dpair", [])
    | (_, Abs _) => raise term_error with ("new_dpair", [])
    | _ => (rbinder, t1, t2);


(*Normalise a disagreement pair:  both terms are put into head normal form
  and eta-normal form, then their common outer abstractions are moved into
  the binder by new_dpair.*)
fun head_norm_dpair (env, (rbinder,t,u)) : dpair =
  let fun prepare tm = eta_norm (rbinder, head_norm(env,tm));
      val t' = prepare t;
      val u' = prepare u
  in  new_dpair (rbinder, t', u')  end;


(*Terms recorded under xname in the depol component of the environment;
  the empty list if there is no entry.*)
fun depslookup (Envir{depol,...}, xname) : term list =
  (case xsearch (depol,xname) of
       Some deps => deps
     | None => []);


(*OCCUR CHECK
  Does the variable uname occur in any of the terms tms?
  There are two forms of search, according to whether there is a rigid path
  to the current term;  this version searches for any (nonrigid) occurrence
  and returns true if one is found.
  Assigned variables are followed through the environment and parameters
  through their dependencies.
  "seen" is the list of variables already passed through;  it is a memo
  variable, updated in place, so that shared variables are visited once.*)
fun occurs_terms (seen: (indexname list) ref,
 		  env: envir, uname: indexname, tms: term list): bool =
  let fun occ_list [] = false
	| occ_list (tm::rest) =  if occ tm then true else occ_list rest
      and occ tm = case tm of
	    Const _ => false
	  | Bound _ => false
	  | Param (pname,_) => occ_list (depslookup(env,pname))
	  | Var (vname,_) =>
	      if vname xmem !seen then false
	      else if uname=vname then true
		(*no need to lookup: uname has no assignment*)
	      else (seen := vname :: !seen;
		    case  elookup(env,vname)  of
			None   => false
		      | Some u => occ u)
	  | Abs(_,_,body) => occ body
	  | rator$rand => if occ rand then true else occ rator
  in  occ_list tms  end;



(* f(a1,...,an)  ---->   f    using the assignments:
   if the head is an assigned variable then the head of its value is taken
   instead.  Only the head is returned, not the argument list. *)
fun head_of_in (env, rator$_) : term = head_of_in(env,rator)
  | head_of_in (env, tm as (Var (uname,_))) =
      (case  elookup(env,uname)  of
	   None   => tm
	 | Some u => head_of_in(env,u))
  | head_of_in (_, tm) = tm;


(* Rigid occur check
Returns 2 if it finds a rigid occurrence of the variable,
        1 if it finds a nonrigid path to the variable,
        0 if it finds no occurrence.
  continues searching for a rigid occurrence even if it finds a nonrigid one.

Condition for detecting non-unifiable terms: [ section 5.3 of Huet (1975) ]
   a rigid path to the variable, appearing with no arguments.
Here completeness is sacrificed in order to reduce danger of divergence:
   reject ALL rigid paths to the variable.
Could check for rigid paths to bound variables that are out of scope.  
Not necessary because the fixedpoint test looks at variable's ENTIRE rbinder.

Treatment of head(arg1,...,argn):
If head is a variable then no rigid path, switch to nonrigid search
for arg1,...,argn. 
If head is an abstraction then possibly no rigid path (head could be a 
   constant function) so again use nonrigid search.  Happens only if
   term is not in normal form. 

"seen" is the memo list of variables already visited; it is updated in place
and is also passed on to occurs_terms for the nonrigid searches.

Warning: finds a rigid occurrence of ?f in ?f(t).
  Should NOT be called in this case: there is a flex-flex unifier
*)
fun rigid_occurs_term (seen: (indexname list) ref,
 		  env: envir,  uname: indexname,  tm) : int = 
  (*nonrigid search of one term, reported as 1 (found) or 0 (not found)*)
  let fun nonrigid tm =  if occurs_terms (seen,env,uname,[tm]) then 1 else 0;
      (*search a list of terms: stop at the first 2, else keep the last
	nonzero result*)
      fun occurs [] = 0
	| occurs (tm::tms) =
            (case occur tm of
               2 => 2
             | oc1 =>  (case occurs tms of 0 => oc1  |  oc2 => oc2))
      (*search a combination with rigid head: each operand, then the head*)
      and occomb (rator$rand) =
            (case occur rand of
               2 => 2
             | oc1 =>  (case occomb rator of 0 => oc1  |  oc2 => oc2))
        | occomb tm = occur tm
      and occur (Const _)  = 0
	| occur (Bound _)  = 0
	| occur (Param (pname,_)) = occurs (depslookup(env,pname))
	| occur (Var (uname',_))  = 
	    if uname' xmem !seen then 0
	    else if uname=uname' then 2
	    else (seen := uname':: !seen;  
	          case  elookup(env,uname')  of
		      None    => 0
		    | Some tm => occur tm)
	| occur (Abs(_,_,body)) = occur body
	| occur (tm as rator$_) =  (*switch to nonrigid search?*)
	   (case head_of_in (env,rator) of
	      Var (uname',_) => (*uname' is not assigned*)
		if uname=uname' then 2  (*rigid occurrence*)
		else  nonrigid tm
	    | Abs(_,_,body) => nonrigid tm (*not in normal form*)
	    | _ => occomb tm)
  in  occur tm  end;



(*Is the term eta-convertible to a single variable with the given rbinder?
  That is, is it a variable applied to exactly the bound variables
  B.(m-1),...,B.n, where m is the length of rbinder and n the starting index?
  Examples: ?C   ?C'(B.0)   ?C''(B.1,B.0)
  Result is var name for use in SIMPL, or None if the term has another form.*)
fun get_eta_var (rbinder, n, tm) =
  case (rbinder, tm) of
      ([], Var(uname,_)) => Some uname
    | (_::rest, rator $ Bound bno) =>
	if  n=bno  then  get_eta_var (rest, n+1, rator)  else  None
    | _ => None;


(* ([xn,...,x1], t)   ======>   (x1,...,xn)t
   Abstracts the body over a reversed binder:  the first element of the list
   becomes the innermost abstraction. *)
fun rlist_abs (rbinder, body) =
  case rbinder of
      [] => body
    | (id,ary)::outer => rlist_abs(outer, Abs(id, ary, body));


(*Fixedpoint test for a flex-rigid pair.
  If t is eta-convertible to a lone variable over rbinder, and the variable
  does not occur in u, then the variable is assigned u abstracted over
  rbinder and the updated environment is returned.
  Result is None if t has another form, or if only a nonrigid path leads to
  the variable.  Raises unify if there is a rigid path to the variable.*)
fun fixedpoint (env, rbinder, t, u) : envir option =
    case get_eta_var(rbinder,0,t) of
       None => None
     | Some uname =>
	  let val path = rigid_occurs_term (ref[], env, uname, u)
	  in  case path of
		  0 => Some (eupdate ((uname, rlist_abs(rbinder, u)),  env))
		| 1 => None
		| 2 => raise unify  (*Rigid path*)
	  end;



(*Simplify one disagreement pair, adding the outcome to the accumulated
  result (env, flexflex, flexrigid).
  flexflex: the flex-flex pairs,  flexrigid: the flex-rigid pairs
  Rigid-rigid pairs with equal heads are decomposed into pairs of operands;
  with unequal heads they raise unify.
  Flex-rigid pairs are stored with the flexible term first.
  NO fixedpoint substitution for flex-flex pairs:
    may create nonrigid paths, which mess up real fixedpoints*)
fun SIMPL0  (dp0, all as (env,flexflex,flexrigid))
  : envir * dpair list * dpair list =
  let val dp as (rbinder,t,u) = head_norm_dpair(env,dp0);
      val heads = (head_of t, head_of u);
      (*pair off corresponding operands of the two combinations*)
      fun simp_operands (funt$argt, funu$argu) =
	    SIMPL0 ((rbinder,argt,argu), simp_operands(funt,funu))
	| simp_operands _ = all;
      (*rigid-rigid: the heads must agree*)
      fun rigid_rigid agree =
	    if agree then simp_operands(t,u) else raise unify;
      (*flex-rigid: assign the variable if possible, else defer the pair*)
      fun flex_rigid (flex, rigid) =
	    (case fixedpoint (env,rbinder,flex,rigid) of
		 Some env' => (env', flexflex, flexrigid)
	       | None      => (env, flexflex, (rbinder,flex,rigid)::flexrigid))
  in
  if alphaconv (t,u) then all
  else case heads of
       (Var _, Var _) => (env, dp::flexflex, flexrigid)
     | (Var _, _) => flex_rigid (t,u)
     | (_, Var _) => flex_rigid (u,t)
     | (Const(idt,_), Const(idu,_)) => rigid_rigid (idt=idu)
     | (Bound bnt,    Bound bnu) => rigid_rigid (bnt=bnu)
     | (Param(idt,_), Param(idu,_)) => rigid_rigid (idt=idu)
     | _ => raise unify
  end;


(*Is the head of the term a variable that has an assignment in env?*)
fun changed (env, tm) = case tm of
    rator$_ => changed (env,rator)
  | Var (uname,_) =>
      (case elookup(env,uname) of  Some _ => true  |  None => false)
  | _ => false;


(*Simplify all the disagreement pairs, giving the updated environment, the
  flex-flex pairs and the flex-rigid pairs.
  Recursion needed if any of the 'head variables' have been updated:  the
  remaining pairs are then simplified again in the new environment.
  Clever would be to re-do just the affected dpairs*)
fun SIMPL (env,dpairs) : envir * dpair list * dpair list =
  let val result as (env',flexflex,flexrigid) =
            itlist_right SIMPL0  (dpairs, (env,[],[]));
      val pending = flexrigid@flexflex;
      fun stale (_,t,u) = changed(env',t) orelse changed(env',u)
  in  if exists stale pending  then  SIMPL (env',pending)  else result  end;



(*computes head(Bound(n+k-1),...,Bound(n))  for arity [a1,...,ak]--->a
  Only the length of the arity list is used.*)
fun combound (head, n, arys) = case arys of
    [] => head
  | _::rest => combound (head,n+1,rest) $ (Bound n);


(*Makes the terms E1,...,Em,    where arys = [ary1...arym].
  Each Ei is   ?Gi(B.(n-1),...,B.0), and has arity aryi
  The B.j are bound vars of binder;  the ?Gi are new variables obtained from
  genvars, so the environment is returned along with the terms.
  The terms are not made in eta-normal-form, SIMPL does that later.
  If done here, eta-expansion must be recursive in the arguments! *)
fun make_args (binder: arity list, env, arys) : envir * term list =
  case arys of
      [] => (env, [])   (*frequent case*)
    | _ =>
       let val (env', vars) =
		genvars (env, map (fn ary => binder--->ary) arys);
	   fun apply_bound var = combound(var,0,binder)
       in  (env',  map apply_bound vars)  end;


(*like list_abs but no names for the bound variables:
  the first arity in the list gives the outermost abstraction*)
fun arities_abs (arys, body) = case arys of
    [] => body
  | ary::rest => Abs("", ary, arities_abs(rest,body));

(*Abstract t namelessly over the binder arities of ary*)
fun arity_abs (ary,t) =
  let val arys = binder_arities ary
  in  arities_abs(arys, t)  end;


(*MATCH taking "big steps".
  Copies the term u, using projection on targs whenever possible.
  A projection is allowed if simpl returns (rather than raising exception).
  Allocates new variables in projection on a higher-order argument,
    or if u is a variable (flex-flex dpair).
  IMITATION of a parameter requires occurs check on dependencies.  
  Returns sequence of every way of projecting/imitating u, to allow backtracking
  For example, projection in ?b'(?a) may be wrong if other dpairs constrain ?a.
  Sequence can be very long!

  Arguments:  tname is the head variable of the flexible term, targs its
  arguments, u the term being copied, rbinder the binder of the pair, and
  ed the current environment paired with the remaining dpairs.
  Each element of the result is a copy of u paired with the environment and
  dpairs that result from making that copy.*)
fun matchcopy (tname, rbinder, targs, u, ed as (env,dpairs)) 
  : (term * (envir * dpair list)) sequence =
  (*copy one argument of u in every possible way, in the environment left by
    copying the arguments uargs that follow it*)
  let fun copycons uarg (uargs, (env, dpairs)) = 
	    maps(fn (uarg', ed') => (uarg'::uargs, ed'))
		(matchcopy (tname, rbinder, targs,  
			eta_norm(rbinder, head_norm(env,uarg)),
			(env, dpairs)));
      (*every way of copying the whole argument list; the last argument
	is copied first*)
      fun copyargs [] = scons( ([],ed), null_sequence)
	| copyargs (uarg::uargs) =
	    flat_sequence (maps (copycons uarg) (copyargs uargs));
      val (uhead,uargs) = strip_comb u;
      (*imitation: rebuild u from its own head and the copied arguments*)
      fun joinargs (uargs',ed') = (list_comb(uhead,uargs'), ed');
     (*attempt projection on argument with given arity*)
      val base = body_arity (fastarity (rbinder,uhead));
      val tarys = map (curry fastarity rbinder) targs;
      (*projection on targ is attempted only if its result arity bary equals
	that of u; tail is the sequence of remaining alternatives*)
      fun projenv (head, (arys,bary), targ, tail) = 
        if base=bary  then 
	  sequenceof (fn () =>  
	    let val (env',args) = make_args (tarys,env,arys);
		(*higher-order projection: plug in targs for bound vars*)
		fun plugin arg = list_comb(head_of arg, targs);
		val dp = (rbinder, list_comb(targ, map plugin args), u);
		(*NOTE: SIMPL returns (env, flexflex, flexrigid), so frigid
		  holds the flex-flex pairs and fflex the flex-rigid pairs;
		  both lists are simply appended below*)
		val (env2,frigid,fflex) = SIMPL (env', dp::dpairs)
		    (*may raise exception unify*)
	    in  case elookup(env2,tname) of
		    None => Some ((list_comb(head,args), (env2, frigid@fflex)),
				  tail)
		  | Some _ => spull tail  (*forbid updating of tname*)
	    end
	    handle unify => spull tail)
        else tail;
      (*try projections*)
      (*one projection per argument of t, in order, followed by imitation;
	the two lists must have the same length*)
      fun PROJFUN (ary::arys, targ::targs) =
	       (projenv(Bound(length arys), strip_arity ary, targ, 
			PROJFUN(arys,targs)))
       	| PROJFUN ([], []) = (*try imitation*)
	      (case uhead of
		 Const _ => maps joinargs (copyargs uargs)
	       | Var _ => null_sequence  (*loop detected!*)
	       | Param (pname,_)  =>
		    if occurs_terms (ref[],env,tname,depslookup(env,pname))
		    then null_sequence   
		    else maps joinargs (copyargs uargs)
	       | _ => null_sequence)
     	| PROJFUN _ = raise term_error with ("PROJFUN", u::targs)
  in  case head_of u  of
	  (*copy beneath the abstraction: the new bound variable is added to
	    the binder and supplied to t as an extra argument*)
	  Abs(name, ary, body) =>
	    maps(fn (body', ed') => (Abs (name,ary,body'), ed')) 
		(matchcopy (tname, (name,ary)::rbinder, 
			(map (incr_boundvars 1) targs) @ [Bound 0],
			body, ed))
        | uhead as Var (uname,uary)   => 
	    if  uname=tname  then PROJFUN(tarys, targs)
	    else (*have encountered a flex-flex dpair, make variable for t*)
	      let val (env', newhd) = genvar(env, tarys ---> Ground base);
           	  val tabs = combound(newhd,0,tarys);
                  val tsub = list_comb(newhd,targs)
              in  sequence_of_list[ (tabs, (env', (rbinder, tsub, u):: dpairs)) ]	      end
	| _ =>  PROJFUN(tarys, targs)
  end;



(*Solve the flex-rigid pair (rbinder,t,u), where t must have a variable as
  its head.  Each copy of u delivered by matchcopy is abstracted over the
  binder arities of the variable and assigned to it.
  Result is the sequence of resulting environments, each paired with its
  remaining dpairs.*)
fun MATCH (env, (rbinder,t,u), dpairs) : (envir * dpair list) sequence =
  let val (Var(tname,tary), targs) = strip_comb t;
      val tarys = binder_arities tary;
      fun assign (u', (env',dpairs')) =
	    let val value = arities_abs(tarys, u')
	    in  (eupdate ((tname,value), env'),   dpairs')  end;
      val copies = matchcopy (tname, rbinder, targs, u, (env,dpairs))
  in  maps assign copies  end;


(*Update, checking Var-Var assignments: try to suppress higher indexes.
  Assigning a variable to itself leaves the environment unchanged.
  If the new value is a variable name' that xless places after name, the
  assignment is reversed (name' receives name), unless name' is already
  assigned, in which case its value is used instead.*)
fun vupdate((name,t), env) =
  let fun plain () = eupdate((name,t), env)
  in  case t of
	  Var(name',ary) =>
	    if name=name' then env	(*cycle!*)
	    else if xless(name, name')  then
	       (case elookup(env,name') of  (*if already assigned, chase*)
		    Some u => vupdate((name,u), env)
		  | None => eupdate((name', Var(name,ary)), env))
	    else plain ()
	| _ => plain ()
  end;


(*Flex-flex fixedpoint
  If t is eta-convertible to a lone variable over rbinder, and that variable
  does not occur in u, then it is assigned u abstracted over rbinder (using
  vupdate).  Result is None if no assignment can be made.*)
fun ffixedpoints (env, rbinder, t, u) : envir option =
  case get_eta_var(rbinder,0,t) of
      None => None
    | Some uname =>
	if  occurs_terms (ref[], env, uname, [u])  then  None
	else
	  let val value = rlist_abs(rbinder, u)
	  in  Some (vupdate ((uname,value), env))  end;


(*Add a tpair if not there already.
  Both terms are first abstracted over rbinder.  The pair counts as present
  if tpairs contains an alpha-convertible pair in either orientation:  the
  pairs stand for equations, so (u,t) expresses the same constraint as (t,u)
  and adding it would only duplicate work for later stages.*)
fun add_tpair (rbinder, (t0,u0), tpairs) : (term*term) list =
  let val t = rlist_abs(rbinder, t0)  and  u = rlist_abs(rbinder, u0);
      fun present (t',u') =
	    (alphaconv(t,t') andalso alphaconv(u,u'))  orelse
	    (alphaconv(t,u') andalso alphaconv(u,t'))	(*swapped pair*)
  in  if exists present tpairs  then tpairs  else (t,u)::tpairs  end;


(*IF the flex-flex dpair is a fixedpoint THEN do it  ELSE  put in tpairs
  The fixedpoint is attempted in both directions.
  eliminates trivial tpairs like t=t, as well as repeated ones
  trivial tpairs can easily escape SIMPL:  ?A=t, ?A=?B, ?B=t gives t=t
  Resulting tpairs MAY NOT be in normal form:  assignments may occur
    here.
  Raises term_error if either term does not have a variable as its head.*)
fun flex_fixedpoint ((rbinder,t0,u0), (env,tpairs)) : envir * (term*term)list =
  let val t = norm_term env t0  and  u = norm_term env u0;
      (*no assignment possible: record the pair*)
      fun keep () = (env, add_tpair (rbinder, (t,u), tpairs))
  in  case  (head_of t, head_of u) of
      (Var(name1,_), Var(name2,_)) =>
	if name1=name2  then  (*occur check would get this wrong!*)
	    if alphaconv(t,u) then (env,tpairs)  else keep ()
	else (case ffixedpoints (env,rbinder,t,u) of
		Some env' => (env',tpairs)
	      | None =>
		 (case ffixedpoints (env,rbinder,u,t) of
		    Some env' => (env',tpairs)
		  | None      => keep ()))
    | _ => raise term_error with ("flex_fixedpoint", [t,u])
  end;


(*Unify the dpairs in the environment.
  Returns flex-flex disagreement pairs NOT IN normal form. 
  SIMPL may raise exception "unify". 

  The result is a sequence of unifiers, each an environment paired with the
  flex-flex pairs that remain.  The search simplifies the dpairs with SIMPL,
  then applies MATCH to the first flex-rigid pair and recurs on every
  alternative it returns.  A branch is abandoned when exception unify is
  raised, or when its depth exceeds search_bound (a message is printed).
  Tracing is printed for nodes deeper than trace_bound.*)
fun unifiers (sign, env, tus : (term*term)list) 
  : (envir * (term*term)list) sequence =
  (*the unifiers reachable from one search node at depth tdepth, followed
    by the sequence resequence*)
  let fun add_unify tdepth ((env,dpairs), resequence) =
        sequenceof (fn()=>
        let val (env',flexflex,flexrigid) = 
	     (if tdepth> !trace_bound andalso !trace_simp
	      then (prs"Enter SIMPL\n";  
		    print_dpairs sign (env,dpairs)) else ();
	      SIMPL (env,dpairs))
	in case flexrigid of
	    (*only flex-flex pairs remain: a unifier has been found*)
	    []   =>  Some (itlist_right flex_fixedpoint (flexflex, (env',[])),
			   resequence)
	  | dp::flexrigid' => 
	      if tdepth > !search_bound then
		  (prs"***Unification bound exceeded\n";  spull resequence)
	      else
	      (if tdepth > !trace_bound then
		  (prs"Enter MATCH\n";
		   print_dpairs sign (env', flexrigid@flexflex))
	       else ();
	       spull (itsequence_right (add_unify (tdepth+1))
   		         (MATCH (env',dp, flexrigid'@flexflex), resequence)))
	end
	handle unify => 
	  (if tdepth > !trace_bound then  prs"Failure node\n"  else ();
           spull resequence));
      (*the initial dpairs have empty binders*)
      val dps = map (fn(t,u)=> ([],t,u)) tus
  in add_unify 0 ((env,dps), null_sequence)  end;


(*For smash_flexflex1
  Name and arity of the variable at the head of t, after normalising t in env
  and stripping its outer abstractions.  Raises unify if the head is not a
  variable.*)
fun var_head_of (env,t) : indexname * arity =
  let val head = head_of (strip_abs_body (norm_term env t))
  in  case head of
	  Var(tname,tary) => (tname,tary)
	| _ => raise unify  (*not flexible, cannot unify by trivial substitution*)
  end;


(*Eliminate a flex-flex pair by the trivial substitution, see Huet (1975)
  Unifies ?f(t,u) with ?g(t,u) by ?f, ?g -> %(x,y)?a, though just ?g->?f
    is a more general unifier.
  Both head variables are assigned constant functions that return the same
  new variable.
  Unlike in Huet (1975), does not smash together all variables of same arity,
    for requires more work yet gives a less general unifier.*)
fun smash_flexflex1 ((t,u), env) : envir =
  let val (tname,tary) = var_head_of (env,t);
      val (uname,uary) = var_head_of (env,u);
      val (env', v) = genvar(env, Ground (body_arity tary));
      val tvalue = arity_abs(tary,v);
      val env_u = vupdate((uname, arity_abs(uary,v)),  env')
  in  vupdate((tname,tvalue), env_u)  end;


(*Smash all flex-flex pairs, threading the environment through
  smash_flexflex1.  Should allow selection of pairs by a predicate?*)
fun smash_flexflex (env,tpairs) : envir =
  let val smashed = itlist_right smash_flexflex1 (tpairs, env)
  in  smashed  end;


(*Returns unifiers with no remaining disagreement pairs:
  the flex-flex pairs of each unifier are smashed into its environment*)
fun smash_unifiers (sign, env, tus) : envir sequence =
  let val unifs = unifiers (sign,env,tus)
  in  maps smash_flexflex unifs  end;


(*handy functional for tpairs: applies f to both components of a pair,
  first component first*)
fun pairself f (x,y) =
  let val x' = f x;
      val y' = f y
  in  (x', y')  end;


(*Prints u then t, like print_dpair, separated by " = " and ending the line*)
fun print_tpair sign (t,u) =
  let fun show tm = print_term sign tm
  in  (show u;  prs" = ";  show t;  prs"\n")  end;


(*Read a pair of terms from strings.
  Raises term_error unless both terms have the same arity.*)
fun read_termpair sign (st,su) : term*term =
  let val tu as (t,u) = (read_term sign st, read_term sign su)
  in  if arity_of t = arity_of u then tu
      else raise term_error with ("read_termpair: arities", [t,u])
  end;

(*Normalise both terms of every pair in the environment*)
fun normpairs env : (term*term) list -> (term*term) list =
  let fun norm_pair (t,u) = (norm_term env t, norm_term env u)
  in  map norm_pair  end;


(*test unification of terms wrt environment
  Prints unifiers via print_sequence (called with the bound 20).  For each
  one the instantiated term pairs, the environment and any remaining
  flex-flex pairs are shown.*)
fun tun sign (env, tus: (term*term) list) : unit =
  let fun show_pairs pairs = seq (print_tpair sign) pairs;
      fun show_unifier k (env',tpairs) =
	(prs"Unifier ";  pri k;  prs"\n";
	 show_pairs (normpairs env' tus);
	 print_env sign env';
	 case tpairs of
	       _::_ => (prs"FlexFlex:\n";  show_pairs tpairs)
	    | [] => ())
  in  print_sequence show_unifier 20 (unifiers (sign,env,tus))  end;


(*Test unification with parsed terms of signature:
  reads each pair of strings, then runs tun in the null environment*)
fun tun_read sign (spairs: (string*string) list) =
  let val tus = map (read_termpair sign) spairs
  in  tun sign (null_env, tus)  end;


(*Similar, for flex-flex unifiers:
  the unifiers come from smash_unifiers, so no flex-flex pairs remain to be
  printed*)
fun ftun sign (env, tus: (term*term) list) : unit =
  let fun show_pair tu = print_tpair sign tu;
      fun show_unifier k env' =
	(prs"Unifier ";  pri k;  prs"\n";
	 seq show_pair (normpairs env' tus);
	 print_env sign env')
  in  print_sequence show_unifier 20 (smash_unifiers (sign,env,tus))  end;


(*Test flex-flex unification with parsed terms of signature:
  reads each pair of strings, then runs ftun in the null environment*)
fun ftun_read sign (spairs: (string*string) list) =
  let val tus = map (read_termpair sign) spairs
  in  ftun sign (null_env, tus)  end;


