(*---------------------------------------------------------------------------
         Substitution in the named lambda calculus
 ---------------------------------------------------------------------------*)

app load ["bossLib", "stringTheory", "pred_setTheory"]; 
open bossLib listTheory stringTheory pred_setTheory;

(* Register the string equations IMPLODE_EQNS and EXPLODE_EQNS with
   computeLib, so that evaluation (presumably EVAL and friends, via the
   default compset) can compute with IMPLODE/EXPLODE on string literals.
   This should already happen by default! *)

val _ = computeLib.add_funs [IMPLODE_EQNS, EXPLODE_EQNS]; 

(* Named lambda-calculus terms: variables, applications (Comb), and
   abstractions (Abs) that bind a string-named variable in a body. *)
Hol_datatype `lam = Var  of string
                  | Comb of lam => lam
                  | Abs  of string => lam`;

(*---------------------------------------------------------------------------
    The system-generated size function is not suitable here, since
    renaming a variable (e.g. priming it) can increase the size of a term
    under that definition. Thus we define our own measure, lam_count,
    which ignores variable names.
 ---------------------------------------------------------------------------*)

(* lam_count M counts the Comb and Abs nodes of M. Variables count 0 and
   the name bound by an Abs is ignored, so replacing one variable name by
   another never changes lam_count. Used as the termination measure for
   subst below. *)
val lam_count_def = 
 Define `(lam_count (Var _) = 0) 
    /\   (lam_count (Comb M N) = 1 + lam_count M + lam_count N)
    /\   (lam_count (Abs _ M) = 1 + lam_count M)`;

(* DEL1 removes the first element of a list that is equal to x.
   x occurs only on the right-hand side, so it is a free parameter of the
   definition; presumably this is why Hol_defn with Defn.tprove is used
   instead of Define (TODO confirm). Should be able to use Define for this.
   Termination: the recursive call is on the tail, so measure LENGTH
   suffices. The induction theorem returned by tprove is discarded. *)

val (DEL1,_) = Defn.tprove
 (Hol_defn "DEL1"
    `(DEL1 [] = []) /\
     (DEL1 (h::t) = if x=h then t else h::DEL1 t)`,
  WF_REL_TAC `measure LENGTH` THEN RW_TAC list_ss []);


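(*   Unused lemmas. Note that they use a list-valued, two-argument FV
     (as in MEM x (FV M A)), not the set-valued FV defined below.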

val MEM_DISTINCT = Q.prove
(`!l a b. MEM a l /\ ~MEM b l ==> ~(a=b)`,
 Induct THEN RW_TAC list_ss [] THEN PROVE_TAC[]);

val MEM_DEL1 = Q.prove
(`!A x s. ~(x=s) ==> (MEM x (DEL1 s A) = MEM x A)`,
 Induct THEN RW_TAC std_ss [MEM,DEL1]);;

val MEM_DEL1_MEM = Q.prove
(`!A x. MEM x (DEL1 x A) ==> MEM x A`,
 Induct THEN RW_TAC std_ss [MEM,DEL1]);;

val MEM_FV_MONO = Q.prove
(`!M A x. MEM x A ==> MEM x (FV M A)`,
Induct THEN RW_TAC std_ss [FV,MEM] THEN PROVE_TAC[MEM_DISTINCT,MEM_DEL1]);
*)

(* Free variables of a term, as a set of strings. An abstraction Abs v M
   removes its bound variable v from the free variables of its body. *)
val FV = 
 Define 
    `(FV (Var x) = {x}) 
 /\  (FV (Comb M N) = (FV M) UNION (FV N))
 /\  (FV (Abs v M)  = FV M DELETE v)`;

(* String concatenation, defined by appending the character lists of the
   two strings and converting back with IMPLODE. *)
val STRCAT =
   Define `STRCAT s1 s2 = IMPLODE (APPEND (EXPLODE s1) (EXPLODE s2))`;

(* PRIME s appends a single prime character to s, e.g. x becomes x'.
   Used by AWAY to generate fresh variable names. *)
val PRIME = 
   Define `PRIME s = STRCAT s "'"`;

(* AWAY set x: keep priming x until the result is not a member of set.
   set occurs free on the right-hand side only. It is later supplied as the
   first argument, as in AWAY (FV M UNION FV Q) v, so presumably Hol_defn
   turns it into an extra (schematic) parameter -- confirm.
   For an infinite set the result is ARB. This is only the defn; the
   termination proof below yields the equations AWAY and induction AWAY_IND. *)
val AWAY_defn =
  Hol_defn "AWAY" 
     `AWAY x = if FINITE set 
               then if x IN set then AWAY (PRIME x) else x
               else ARB`;

(* isPrefix s1 s2 holds when s1 is a prefix of s2, i.e. s2 = STRCAT s1 s3
   for some s3. Partially applied, isPrefix x is the set of strings that
   have x as a prefix; the AWAY termination measure uses it that way. *)
val isPrefix =
   Define `isPrefix s1 s2 = ?s3. s2 = STRCAT s1 s3`;

(* Characterises GSPEC (the constant underlying set comprehensions) as a
   predicate: v is in GSPEC f iff f x = (v,T) for some x. Lets set
   comprehensions be unfolded when sets are treated as predicates.
   Proof: extensionality on both arguments, then GSPECIFICATION. *)
val GSPEC_DEF = Q.prove
(`GSPEC = \f v. ?x. (v,T) = f x`,
 NTAC 2 (CONV_TAC FUN_EQ_CONV THEN GEN_TAC) THEN
 PROVE_TAC [GSPECIFICATION,SPECIFICATION]);

(* Intersection is monotone in its second argument with respect to
   SUBSET. Proved by unfolding SUBSET, INTER and set membership. *)
val INTER_MONO = Q.prove
(`!P Q R. Q SUBSET R ==> P INTER Q SUBSET P INTER R`,
 RW_TAC std_ss [SUBSET_DEF, SPECIFICATION,INTER_DEF,GSPEC_DEF]);

(* STRCAT is associative; follows from associativity of list APPEND
   once STRCAT is unfolded to character lists. *)
val STRCAT_ASSOC = Q.prove
(`!s1 s2 s3. STRCAT s1 (STRCAT s2 s3) = STRCAT (STRCAT s1 s2) s3`,
 RW_TAC std_ss [STRCAT,IMPLODE_11,EXPLODE_IMPLODE,APPEND_ASSOC]);

(* Left cancellation for STRCAT: a common prefix s1 can be removed from
   both sides of an equation. *)
val STRCAT_11 = Q.prove
(`!s1 s2 s3. (STRCAT s1 s2 = STRCAT s1 s3) = (s2=s3)`,
 RW_TAC std_ss [STRCAT,IMPLODE_11,EXPLODE_11,APPEND_11]);

(* Appending s2 leaves s1 unchanged exactly when s2 is the empty string.
   Proof: first show that every x equals STRCAT x with the empty string,
   then cancel the common prefix with STRCAT_11. *)
val STRCAT_ACYCLIC = Q.prove
(`!s1 s2. (s1 = STRCAT s1 s2) = (s2 = "")`,
 RW_TAC std_ss [] THEN 
 `!x. x = STRCAT x ""` by RW_TAC list_ss [STRCAT,EXPLODE_EQNS,IMPLODE_EXPLODE]
 THEN PROVE_TAC [STRCAT_11]);


(*---------------------------------------------------------------------------
   Termination of AWAY. The function terminates because the number of
   strings in set that have x as a prefix strictly decreases with each
   recursive call (this requires set to be FINITE).
 ---------------------------------------------------------------------------*)

(* Termination proof for AWAY, yielding its equations (AWAY) and induction
   theorem (AWAY_IND). The measure is the cardinality of the set of elements
   of set that have x as a prefix. It strictly shrinks on the recursive call
   because set INTER isPrefix (PRIME x) is a proper subset of
   set INTER isPrefix x, and CARD_PSUBSET turns that into a CARD decrease. *)
val (AWAY,AWAY_IND) = Defn.tprove
(AWAY_defn,
 WF_REL_TAC `measure \x. CARD (combin$C $IN set INTER isPrefix x)`
   THEN RW_TAC std_ss [SPECIFICATION,combinTheory.C_DEF]
   THEN MATCH_MP_TAC (Ho_Rewrite.REWRITE_RULE 
          [GSYM RIGHT_FORALL_IMP_THM,AND_IMP_INTRO] CARD_PSUBSET) 
   THEN CONJ_TAC THENL 
   [(* the larger intersection is finite, since set is *)
    PROVE_TAC [INTER_FINITE],
    (* proper-subset goal, split into its subset and inequality parts *)
    RW_TAC std_ss [PSUBSET_DEF,SUBSET_DEF,INTER_DEF,GSPEC_DEF,SPECIFICATION]
    THENL [(* subset: a string extending PRIME x by s3 extends x by a prime
              followed by s3 *)
           POP_ASSUM MP_TAC THEN REPEAT (POP_ASSUM (K ALL_TAC))
            THEN RW_TAC std_ss [isPrefix,PRIME]
            THEN Q.EXISTS_TAC `STRCAT "'" s3` 
            THEN RW_TAC std_ss [STRCAT_ASSOC],
      (* inequality: x itself witnesses the difference; x has prefix x,
         but not prefix PRIME x (by STRCAT_ACYCLIC, the prime is not empty) *)
      CONV_TAC (DEPTH_CONV FUN_EQ_CONV)
        THEN RW_TAC std_ss [isPrefix]
        THEN Q.EXISTS_TAC `x`
        THEN RW_TAC std_ss [PRIME,GSYM STRCAT_ASSOC,STRCAT_11,STRCAT_ACYCLIC]
        THEN RW_TAC list_ss [STRCAT,EXPLODE_EQNS,IMPLODE_EQNS]]]);


(*---------------------------------------------------------------------------
     Next, the capture-avoiding substitution algorithm
 ---------------------------------------------------------------------------*)

(* subst (x,Q) M: capture-avoiding substitution of Q for the free
   occurrences of x in M. The termination relation is given explicitly as
   measure (lam_count o SND), i.e. lam_count of the term argument, because
   the Abs case contains a nested call: the outer subst is applied to the
   result of renaming v to v' in M. When capture would happen, the fresh
   name v' is chosen by AWAY outside FV M UNION FV Q.
   try (from HOL's Lib) presumably reports any exception raised while
   making the definition before re-raising it. *)
val subst_defn =
 try Defn.Hol_Rdefn 
   "subst"   
   `measure (lam_count o SND)`

   `(subst (x,Q) (Var v)    = if x=v then Q else Var v)
 /\ (subst (x,Q) (Comb M N) = Comb (subst (x,Q) M) (subst (x,Q) N))
 /\ (subst (x,Q) (Abs v M)  = 
     if x=v then Abs v M else
     if x IN (FV M) /\ v IN (FV Q)   (* capture would happen *)
     then let v' = AWAY (FV M UNION FV Q) v
          in Abs v' (subst (x,Q) (subst (v,Var v') M))
     else Abs v (subst (x,Q) M))`;

(*---------------------------------------------------------------------------
     Termination is hard to prove because of the nested recursive call in
     the Abs case, and also because proving the termination conditions
     (TCs) of defns is currently very clumsy.
 ---------------------------------------------------------------------------*)

(* Simplify the termination conditions of subst_defn: discharge
   well-foundedness of the measure (WF_measure), unfold o, measure,
   inv_image and lam_count, then rewrite away the TCs of the non-nested
   recursive calls using simple arithmetic facts. The same conversion is
   repeated for aux_defn1 and aux_eqns2 below.
   Note that u is vacuous in the first DECIDE. *)
val subst_defn1 = 
  Defn.simp_tcs subst_defn 
   (REWRITE_CONV [prim_recTheory.WF_measure] THENC
    TotalDefn.TC_SIMP_CONV 
        [combinTheory.o_DEF,prim_recTheory.measure_def,
         relationTheory.inv_image_def,lam_count_def] THENC
    REWRITE_CONV [DECIDE `!x u. x < 1 + x`,
                  DECIDE `!x u. x < 1 + x + u`,
                  DECIDE `!x u. x < 1 + u + x`]);

(* The following should be made to work better:

     Defn.prove_tcs subst_defn (TotalDefn.TC_SIMP_TAC [] []);
*)

(* The auxiliary definition of subst_defn1, presumably created to handle
   the nested recursive call. If Defn.aux_defn returns NONE, fail with an
   explanatory message rather than an anonymous Bind exception from a
   non-exhaustive val pattern. *)
val aux_defn =
  case Defn.aux_defn subst_defn1 of
    SOME d => d
  | NONE => raise Fail "subst_defn1: expected an auxiliary definition for the nested recursion";
(* Apply the same termination-condition simplification used for
   subst_defn1 to the auxiliary definition. *)
val aux_defn1 = 
  Defn.simp_tcs aux_defn 
   (REWRITE_CONV [prim_recTheory.WF_measure] THENC
    TotalDefn.TC_SIMP_CONV 
        [combinTheory.o_DEF,prim_recTheory.measure_def,
         relationTheory.inv_image_def,lam_count_def] THENC
    REWRITE_CONV [DECIDE `!x u. x < 1 + x`,
                  DECIDE `!x u. x < 1 + x + u`,
                  DECIDE `!x u. x < 1 + u + x`]);

(* The (conditional) equations of the auxiliary definition, followed by the
   same simplification applied to their side conditions (the TCs), so that
   the remaining hypotheses are only the nontrivial ones. *)
val aux_eqns1 = Defn.eqns_of aux_defn1;
val aux_eqns2 = CONV_RULE
  (REWRITE_CONV [prim_recTheory.WF_measure] THENC
    TotalDefn.TC_SIMP_CONV 
        [combinTheory.o_DEF,prim_recTheory.measure_def,
         relationTheory.inv_image_def,lam_count_def] THENC
    REWRITE_CONV [DECIDE `!x u. x < 1 + x`,
                  DECIDE `!x u. x < 1 + x + u`,
                  DECIDE `!x u. x < 1 + u + x`]) aux_eqns1;


(* The individual auxiliary equations, presumably in clause order
   Var, Comb, Abs. E3 (the Abs clause) is used in the termination proof
   below. If the number of conjuncts is not three, fail with an explanatory
   message rather than an anonymous Bind exception from a non-exhaustive
   val pattern. *)
val (E1,E2,E3) =
  case CONJUNCTS aux_eqns2 of
    [e1,e2,e3] => (e1,e2,e3)
  | eqns => raise Fail ("aux_eqns2: expected 3 equations, got "
                        ^ Int.toString (List.length eqns));

(*---------------------------------------------------------------------------
    Finally, the termination proof. It's easy except for the 
    nested call case.
 ---------------------------------------------------------------------------*)

(* Termination proof for subst_defn1, yielding its equations (subst_eqns)
   and induction theorem (subst_ind). After discarding assumptions and
   generalizing the fresh name, the goal is proved by measure induction on
   lam_count M, with a case split on the shape of M.
   NOTE(review): the goal-stack session after this proof re-attempts the
   same goal another way, which suggests this proof may not go through as
   written -- confirm. *)
val (subst_eqns, subst_ind) = Defn.tprove
(subst_defn1,
 RW_TAC std_ss [DECIDE `x<1+y = x<=y`]
  THEN REPEAT (POP_ASSUM (K ALL_TAC))  (* Generalize: drop all assumptions *)
  THEN Q.SPEC_TAC (`AWAY (FV M UNION FV Q) v`, `w`)  (* abstract the fresh name as w *)
  THEN Q.ID_SPEC_TAC `v`
  THEN measureInduct_on `lam_count M` 
  THEN Cases_on `M` THEN POP_ASSUM MP_TAC
  THEN RW_TAC arith_ss [lam_count_def,
         DECIDE`x<y+1 = x<=y`, DECIDE `y<p+(q+1) = y<=p+q`] THENL
  [RW_TAC arith_ss [lam_count_def,aux_eqns2],  (* Var *)
   RW_TAC arith_ss [lam_count_def,aux_eqns2,   (* Comb *)
           DECIDE `x<=p /\ y<=q ==> x+(y+1) <= p+(q+1)`],
   MP_TAC (Q.INST  (* Abs: start with delicate instantiation *)
      [`M:lam`    |-> `l:lam`,
       `v:string` |-> `s:string`,
       `x:string` |-> `v:string`,
       `Q:lam`    |-> `Var w:lam`] (SPEC_ALL E3))   (* INST problem: v is both replaced (by s) and a replacement (for x); this relies on INST substituting simultaneously *)
      THEN RW_TAC std_ss [DECIDE `x<1+y = x<=y`] 
      THENL [RW_TAC arith_ss [lam_count_def], ALL_TAC,
             RW_TAC arith_ss [lam_count_def]]
      (* tricky nested case. 
           In order to unroll subst_tupled_aux in the goal, 
           we first prove the nested TC, by using the inductive hypothesis. *)
      THEN `!v w. lam_count
                    (subst_tupled_aux(\x y. lam_count(SND x)<lam_count(SND y))
                      ((v,Var w),l)) <= lam_count l`
           by RW_TAC arith_ss []
      THEN RW_TAC arith_ss [lam_count_def]   (* unroll subst_tupled_aux *)
      THEN Q.PAT_ASSUM `p ==> q` (K ALL_TAC) (* g.c.: discard an implication assumption that is no longer needed *)
      THEN PROVE_TAC [LESS_EQ_TRANS]]);


(* Interactive (goal-stack) re-attempt of the subst termination proof.
   NOTE(review): restart() needs a goal already on the goal stack, and none
   is set in this script as shown; presumably it was set beforehand, e.g.
   via Defn.tgoal subst_defn1 -- confirm.
   The tactic drops all assumptions, abstracts the fresh name
   AWAY (FV M UNION FV Q) v as w, replaces Var w by a new variable u while
   keeping u = Var w as an antecedent, and finally generalizes w, M, u and v
   so that an induction principle can be applied. *)
restart();;
e (RW_TAC std_ss [DECIDE `x<1+y = x<=y`]
  THEN REPEAT (POP_ASSUM (K ALL_TAC))  (* Generalize *)
 THEN Q.SPEC_TAC (`AWAY (FV M UNION FV Q) v`, `w`)
 THEN GEN_TAC 
 THEN `?u. u = Var w` by PROVE_TAC []
 THEN POP_ASSUM (fn th => SUBST1_TAC (SYM th) THEN MP_TAC th)
 THEN Q.ID_SPEC_TAC `w`
 THEN Q.ID_SPEC_TAC `M`
 THEN Q.ID_SPEC_TAC `u`
 THEN Q.ID_SPEC_TAC `v`);

(* Induct with the induction theorem of the auxiliary definition aux_defn1,
   then simplify. valOf raises Option if aux_defn1 has no induction
   theorem. Subgoals 1 and 2 are closed by rewriting with lam_count_def and
   aux_eqns2. *)
e (recInduct (valOf(Defn.ind_of aux_defn1)));
e (RW_TAC std_ss []);
(*1*)
e (RW_TAC arith_ss [lam_count_def,aux_eqns2]);
(*2*) 
e (RW_TAC arith_ss [lam_count_def,aux_eqns2]);
(*3*)
(* NOTE(review): subgoal 3 (presumably the nested Abs case) is not yet
   handled in this script. *)