(*---------------------------------------------------------------------------
    Substitution in the lambda calculus.
 ---------------------------------------------------------------------------*)

(* Session setup: load the bossLib proof tools and open bossLib and
   listTheory into the top-level environment. *)
load "bossLib"; open bossLib listTheory;
(* Declare "by" as an infix identifier at precedence 8. *)
infix 8 by;

(* Set the suffix Define attaches to stored definition names to "";
   presumably so that e.g. DEL is stored under "DEL" rather than
   "DEL_def" -- TODO confirm against the Defn documentation. *)
val _ = Defn.def_suffix := "";


(*---------------------------------------------------------------------------
        Untyped lambda calculus terms
 ---------------------------------------------------------------------------*)

(* Lambda terms, polymorphic in the type 'a of variables and the type 'b
   of constants. *)
Hol_datatype `lam = Var   of 'a 
                  | Const of 'b
                  | Comb  of lam => lam
                  | Abs   of 'a => lam`;


(*---------------------------------------------------------------------------
     Free variables of a term
 ---------------------------------------------------------------------------*)

(* DEL L x removes the FIRST occurrence of x from the list L; any later
   occurrences are kept, since the tail t is returned unchanged once x is
   found.  DEL [] x = [].  ("__" is just an unused pattern variable.) *)
val DEL =
 Define `(DEL (h::t) x = if x=h then t else h::DEL t x)
    /\   (DEL   []  __ = [])`;

(* FV M A: the free variables of M, accumulated onto the list A.
   - A variable is consed onto the accumulator only if it is not already
     MEM of it, so this clause never introduces a duplicate.
   - Comb threads the accumulator through N and then M.
   - For Abs v M, the bound variable v is removed from FV M A with DEL,
     unless v was already in A (then it must stay).  When v is not in A,
     v can occur at most once in FV M A (the Var clause only adds absent
     variables), so DEL's first-occurrence removal is enough.
   FV M [] is the plain list of free variables of M. *)
val FV = 
 Define 
    `(FV (Var x) A    = if MEM x A then A else x::A) 
 /\  (FV (Const _) A  = A)
 /\  (FV (Comb M N) A = FV M (FV N A))
 /\  (FV (Abs v M) A  = let MFV = FV M A
                        in if MEM v A then MFV else DEL MFV v)`;

(*---------------------------------------------------------------------------
    Repeatedly prime "x" until the result is not in the list "L". The 
    definition is parametric in PRIME, so that one can defer the decision
    of how renaming actually gets done. This lets the "lam" type be
    polymorphic in the representation of variables. It means that we 
    can't prove AWAY terminates until an instantiation for PRIME is
    provided. (Not strictly: we could instead constrain PRIME abstractly,
    e.g. require that the iterates x, PRIME x, PRIME (PRIME x), ... are
    pairwise distinct, so that they eventually escape the finite list L,
    and then prove that the still-abstract algorithm terminates.)
 ---------------------------------------------------------------------------*)

(* Hol_defn builds the definition without proving termination.  L and
   PRIME occur free on the right-hand side, so they become extra
   parameters; judging from the call in SUBST_defn
   (AWAY (FV M (v::Sc)) PRIME v), the argument order is AWAY L PRIME x --
   TODO confirm against the printed equation. *)
val AWAY_defn =
  Hol_defn "AWAY" 
     `AWAY x = if MEM x L then AWAY (PRIME x) else x`;

(*---------------------------------------------------------------------------
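     The substitution function. Terminates because every recursive call
     is on a term of smaller size (lam_size). For the nested call this is
     not structural: it relies on the renamed body SUBST (v,Var v') M []
     having the same size as M.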
 ---------------------------------------------------------------------------*)

(* Var_subst (x,P) v Sc: the variable case of substitution.  A variable
   bound by an enclosing abstraction (MEM v Sc) is left alone; otherwise
   v is replaced by P when it is the variable x being substituted for. *)
val Var_subst = 
 Define `Var_subst (x,P) v Sc = 
            if MEM v Sc then Var(v) else 
            if v=x then P else (Var v)`;

(* SUBST (x,Q) M Sc substitutes Q for the free occurrences of x in M.
   Sc lists the variables bound by enclosing abstractions: Var_subst leaves
   variables in Sc alone, and an abstraction is returned unchanged when x
   is its binder or is in Sc.

   In the Abs case, if x occurs free in M and the binder v is free in Q,
   pushing Q under the binder would capture v.  The binder is renamed to a
   fresh v' and the renaming is applied to M first (the nested call).  The
   fresh name must avoid the free variables of Q as well as those of M,
   v and Sc.  Otherwise a v' that is free in Q would be captured by the new
   binder Abs v' once Q is substituted into the body.

   Termination is by the measure lam_size of the term argument.  For the
   nested call, the termination proof has to show that the renamed body
   SUBST (v,Var v') M [] is no larger than M.  Any tactic that generalises
   over the AWAY expression below (e.g. Q.SPEC_TAC in a termination proof)
   must quote exactly the same term. *)
val SUBST_defn =
 Defn.Hol_Rdefn 
   "SUBST"   
   `measure (\(x,y,z). lam_size (\v. 0) (\v. 0) y)`

   `(SUBST pr    (Var v)    Sc = Var_subst pr v Sc)
 /\ (SUBST  ___  (Const c)  __ = Const c)
 /\ (SUBST (x,Q) (Comb M N) Sc = Comb (SUBST (x,Q) M Sc) (SUBST (x,Q) N Sc))
 /\ (SUBST (x,Q) (Abs v M)  Sc = 
     if MEM x (v::Sc) then Abs v M else          (* x is bound in v::Sc *)
     if MEM x (FV M []) /\ MEM v (FV Q []) then  (* capture would happen *)
        let v' = AWAY (FV M (FV Q (v::Sc))) PRIME v  (* fresh wrt M,Q,v,Sc *)
        in Abs v' (SUBST (x,Q) (SUBST (v,Var v') M []) (v'::Sc))
     else Abs v (SUBST (x,Q) M (v::Sc)))`;

(*---------------------------------------------------------------------------
    Termination proof. Tricky because SUBST is nested. The easy cases
    get proved straightaway. The nested case proceeds by complete
    induction on the size of the lambda term, and then case analysis
    on the term structure. All the cases are easy except for the nested
    case, where the inductive hypothesis needs to be used twice.

    The proof is actually the termination proof for the auxiliary 
    function. The auxiliary function is the function from which the 
    specified SUBST function is derived. Thus one has to instantiate
    the recursion equations for the auxiliary function with the 
    termination relations (and then clean them up). This is why the
    equations E1 ... E4 are derived and applied in the termination
    proof. The abbreviation "Fn" is used purely for readability in
    the proofs.
 ---------------------------------------------------------------------------*)

(* Fn abbreviates the auxiliary function SUBST_tupled_aux instantiated with
   the same measure used in SUBST_defn.  It is folded into theorems and goals
   (via SYM Fn_def) purely for readability. *)
val Fn_def = 
  Define
    `Fn = SUBST_tupled_aux (measure (\(x,y,z). lam_size (\v.0) (\v.0) y))`;

(* Derive cleaned-up versions of the auxiliary function's theorems for use
   in the termination proof of SUBST:
     E1..E4  - the four recursion equations, with the auxiliary function
               folded to Fn and the termination conditions in "trivs"
               rewritten away; E4 (the Abs clause) is further simplified
               with the definitions of o, measure and inv_image;
     Fn_ind  - the induction theorem, cleaned up the same way;
     Fn_ind1 - Fn_ind with P generalised and the first two quantified
               variables specialised to v and Var w (presumably the
               substituted pair, i.e. a variable renaming) -- TODO confirm
               against the printed theorem.
   The "val SOME" / "val [...]" bindings raise Bind if SUBST_defn has no
   auxiliary definition or its equations are not exactly four conjuncts. *)
local val SOME aux_defn = Defn.aux_defn SUBST_defn
      val SOME ind  = Defn.ind_of aux_defn
      val [e1,e2,e3,e4] = CONJUNCTS (Defn.eqns_of aux_defn)
      (* prove fails unless TC_SIMP_TAC solves every termination condition
         of the auxiliary definition (presumably the easy, non-nested ones) *)
      val trivs = prove(list_mk_conj(Defn.tcs_of aux_defn),
                        TotalDefn.TC_SIMP_TAC [] [])
in
val [E1,E2,E3,E4a] = map (REWRITE_RULE [SYM Fn_def,trivs]) [e1,e2,e3,e4]
(* Unfold the measure in the Abs equation so its guard is stated directly
   in terms of lam_size. *)
val E4 = simpLib.SIMP_RULE arith_ss 
       [combinTheory.o_DEF,prim_recTheory.measure_def,
              relationTheory.inv_image_def] E4a
val Fn_ind = simpLib.SIMP_RULE std_ss []
                 (REWRITE_RULE [SYM Fn_def,trivs] (DISCH_ALL ind))
val Fn_ind1 = Q.GEN `P` (DISCH_ALL (Q.GEN`v` 
               (Q.SPECL [`v`,`Var w`] (UNDISCH(SPEC_ALL Fn_ind)))))
end;

(* Fetch the size equations for lam from the current theory. *)
val lam_size_def = definition "lam_size_def";
(* Type antiquotation for ('a,'b)lam, spliced into quotations as ^lamty. *)
val lamty = ty_antiq(Type`:('a,'b)lam`);

(*---------------------------------------------------------------------------
    UNFINISHED interactive termination proof for SUBST_defn.

    The tactic below is syntactically incomplete, so the script cannot be
    loaded past this point:
      - the THENL list for the Abs case has a "??" placeholder;
      - the two THENL lists and the final lemma list are never closed;
      - ASM_REWRITE_TAC[] is juxtaposed with a further RW_TAC.
    It is kept verbatim inside this comment so that everything above still
    loads.  To resume, set the goal with Defn.tgoal SUBST_defn and replay
    the tactic step by step.  The Q.SPEC_TAC step must quote exactly the
    AWAY expression that occurs in SUBST_defn.

Defn.tgoal SUBST_defn;
REWRITE_TAC [SYM Fn_def]
  THEN TotalDefn.TC_SIMP_TAC [] []
  THEN RW_TAC std_ss [DECIDE `x<1+y = x<=y`]
  THEN REPEAT (POP_ASSUM (K ALL_TAC))
  THEN Q.SPEC_TAC (`AWAY (FV M (v::Sc)) PRIME v`, `w`) 
  THEN Q.SPEC_TAC (`[]`,`Scope`)
  THEN Q.ID_SPEC_TAC `M`
  THEN Q.ID_SPEC_TAC `v` 
  THEN recInduct Fn_ind1
  THEN RW_TAC std_ss []
  THENL
    [RW_TAC arith_ss [E1,lam_size_def,Var_subst],
     RW_TAC arith_ss [E2,lam_size_def],
     RW_TAC arith_ss [E3,lam_size_def],
     MP_TAC (REWRITE_RULE [lam_size_def,DECIDE `x<1+y = x<=y`,
              arithmeticTheory.ADD_CLAUSES] E4) 
     THEN RW_TAC std_ss [] 
     THENL
       [RW_TAC arith_ss [E4,lam_size_def],
        ??,
     MP_TAC E4 THEN ASM_REWRITE_TAC[]
RW_TAC arith_ss [E4,lam_size_def
 ---------------------------------------------------------------------------*)
(*
val (SUBST_eqns, SUBST_ind) = 
Defn.tprove(SUBST_defn,
REWRITE_TAC [SYM Fn_def]
  THEN TotalDefn.TC_SIMP_TAC [] []
  THEN REWRITE_TAC [DECIDE `x<1+y = x<=y`]
  THEN REPEAT GEN_TAC THEN DISCH_THEN (K ALL_TAC)
  THEN Q.SPEC_TAC (`[]`,`Scope`)
  THEN Q.ID_SPEC_TAC `v` THEN Q.ID_SPEC_TAC `v'`
  THEN measureInduct_on `lam_size (\v. 0) (\v. 0) M`
  THEN Cases_on `M` THEN RW_TAC std_ss [] THENL
  [RW_TAC arith_ss [E1,lam_size_def,Var_subst],
   RW_TAC arith_ss [E2,lam_size_def],
   RW_TAC arith_ss [E3,lam_size_def,
      DECIDE `x<=p /\ y<=q ==> x+(y+1) <= p+(q+1)`],
   MP_TAC (Q.INST [`x:'a`    |-> `v:'a`,
                  `Q:^lamty` |-> `Var v':^lamty`,
                  `v:'a`     |-> `a:'a`,
                  `M:^lamty` |-> `l:^lamty`, 
                `Sc:'a list` |-> `Scope:'a list`] E4)
    THEN RW_TAC std_ss [] THENL
    [RW_TAC arith_ss [],
     Q.PAT_ASSUM `!y. Z y` 
         (fn th => MP_TAC (Q.SPEC `l` th) THEN ASSUME_TAC th)
       THEN REWRITE_TAC [lam_size_def,DECIDE`x<1+x`,
                         arithmeticTheory.ADD_CLAUSES] THEN DISCH_TAC
       THEN Q.PAT_ASSUM `x ==> y` MP_TAC
       THEN RW_TAC arith_ss [lam_size_def,DECIDE`x<y+1 = x<=y`]
       THEN POP_ASSUM (K ALL_TAC)
       THEN RULE_ASSUM_TAC 
            (REWRITE_RULE [lam_size_def,DECIDE`x<1+y = x<=y`,
                    arithmeticTheory.ADD_CLAUSES])
       THEN PROVE_TAC [DECIDE `!x y z. x<=y /\ y<=z ==> x<=z`],
     Q.PAT_ASSUM `!y. Z y` 
         (fn th => MP_TAC (Q.SPEC `l` th) THEN ASSUME_TAC th)
       THEN REWRITE_TAC [lam_size_def,DECIDE`x<1+x`,
                         arithmeticTheory.ADD_CLAUSES]
       THEN DISCH_TAC THEN Q.PAT_ASSUM `x ==> y` MP_TAC
       THEN RW_TAC arith_ss [lam_size_def,DECIDE`x<y+1 = x<=y`]]]);
*)
