(*---------------------------------------------------------------------------
           McCarthy's 91 function.
 ---------------------------------------------------------------------------*)

app load ["bossLib", "numLib"];

open TotalDefn bossLib numLib;

(* ARITH q: parse the quotation q as a HOL term, then prove it with the  *)
(* decision procedure DECIDE. Used below as shorthand for small          *)
(* arithmetic lemmas.                                                    *)
fun ARITH q = DECIDE (Term q);

(*---------------------------------------------------------------------------
       Define the 91 function. We call it "N". We use Hol_defn to 
       make the definition, since we tackle the termination proof 
       ourselves.
 ---------------------------------------------------------------------------*)

(* The 91 function, defined with Hol_defn so that termination is left   *)
(* for us to prove by hand (see the comment above). Until that proof is  *)
(* done, its equations are constrained by the termination conditions.    *)
val N_defn = 
 Hol_defn "N" 
          `N(x) = if x>100 then x-10 else N (N (x+11))`;

(* The definition has a single clause, so exactly one equation is        *)
(* expected; the list pattern raises Bind if that is not the case.       *)
val [Neqn] = Defn.eqns_of N_defn;

(*---------------------------------------------------------------------------
      Prove partial correctness for N, to see how such a proof
      works when the termination relation has not yet been supplied.
      The proof is short but a bit slow: its brevity comes from leaning
      on automation, and the rewriter invokes the arithmetic decision
      procedure repeatedly.
 ---------------------------------------------------------------------------*)

(* Partial correctness of N. The termination relation is still the free  *)
(* variable R here, so WF R and the two termination conditions (stated    *)
(* at R) appear as explicit hypotheses of the theorem.                    *)
(* Nind is the induction theorem for N_defn. It is local because only     *)
(* this proof uses it. Defn.ind_of returns an option, so the SOME pattern  *)
(* raises Bind if no induction theorem is available.                      *)
local val SOME Nind = Defn.ind_of N_defn
in
val Npartly_correct = Q.prove(
 `WF R /\ 
  (!x. ~(x > 100) ==> R (N_aux R (x + 11)) x) /\
  (!x. ~(x > 100) ==> R (x + 11) x)
     ==>
  !n. N(n) = if n>100 then n-10 else 91`,
(* Assume the hypotheses, then induct along the recursion of N. *)
STRIP_TAC THEN recInduct Nind
  (* Simplify, using the arithmetic fact x+11-10 = x+1. *)
  THEN RW_TAC std_ss [ARITH `x+11-10 = x+1`]
  (* Unfold N exactly once with its recursion equation... *)
  THEN ONCE_REWRITE_TAC [Neqn]
  (* ...and finish with arithmetic-aware rewriting. *)
  THEN RW_TAC arith_ss [])
end;


(*---------------------------------------------------------------------------
      Prove termination for N. Unlike normal recursive functions, this
      is a multi-step process.

      1. Figure out what the termination relation should be. For 91,
         one termination relation that works is "measure \x. 101 - x".
         (Determined by introspection).

      2. Instantiate the defn with the relation, via Defn.set_reln.

      3. Find the termination conditions of (2), via Defn.tcs_of.

      4. Prove the non-nested termination conditions. 

      5. Remove (4) from the result of (2), via Defn.elim_tcs.
         The easy stuff is now finished.

      6. Get the auxiliary definition from the result of (5), 
         via Defn.aux_defn.

      7. Get the equations from (6), via Defn.eqns_of.

      8. The result of (7) is a list of equations, constrained by the 
         nested termination condition. Simplify the nested t.c.
         if possible: this will make the termination proof easier. 
         TotalDefn.TC_SIMP_CONV can be useful here.

      9. Do the termination proof, using Defn.tgoal on (5).

     10. Successful completion of (9) yields a tactic "t". Wrap
         everything up with an invocation of Defn.tprove on (5) and t.
         
 ---------------------------------------------------------------------------*)

(*1,2*)

(* Instantiate the termination relation of N_defn with the measure       *)
(* chosen in step 1, namely 101 - x.                                     *)
val N_defn_1 =
  let val measure_reln = Term `measure \x. 101 - x`
  in  Defn.set_reln N_defn measure_reln
  end;

(*3*)

(* The termination conditions of N_defn_1. Exactly three are expected    *)
(* (the binding raises Bind otherwise). The middle one is named nested;   *)
(* tc1 and tc2 are the non-nested ones, discharged in step 4.            *)
val tcs as [tc1,nested,tc2] = Defn.tcs_of N_defn_1;

(*4*)

(* Prove the two non-nested termination conditions together, as a single *)
(* conjunction, using the termination-condition simplification tactic.   *)
(* Step 5 splits the resulting theorem back into its conjuncts.          *)
val non_nested_tcs = prove
(mk_conj(tc1,tc2),
 TotalDefn.TC_SIMP_TAC [] []);

(*5*)

(* Remove the proved non-nested conditions from N_defn_1. Only the       *)
(* nested termination condition remains outstanding after this.          *)
val N_defn_2 =
  let val proved_tcs = CONJUNCTS non_nested_tcs
  in  Defn.elim_tcs N_defn_1 proved_tcs
  end;

(*6*)

(* The auxiliary definition carried by N_defn_2 (presumably the one for  *)
(* N_aux). Defn.aux_defn returns an option, so the SOME pattern raises    *)
(* Bind if there is none.                                                *)
val SOME aux_defn_2 = Defn.aux_defn N_defn_2;

(*7*)

(* The first equation of the auxiliary definition. DISCH_ALL turns its   *)
(* constraining hypotheses into antecedents so step 8 can simplify them. *)
val E0  = DISCH_ALL (hd (Defn.eqns_of aux_defn_2));
(* On the recursive branch (x not above 100), comparing 101-y < 101-x is  *)
(* the same as comparing x<y.                                            *)
val lem = ARITH `~(x > 100) ==> (101-y < 101-x = x<y)`;
(* Unfold measure into a comparison of the measured values.             *)
val lem1 = Q.prove(`measure f x y = f x < f y`, 
            RW_TAC std_ss [prim_recTheory.measure_def, 
                           relationTheory.inv_image_def]);
(*8*)

(* Simplify the nested t.c. in E0 with lem and lem1, then generalize the  *)
(* free variables so that E1 can be instantiated with Q.SPEC below.       *)
val E1  = GEN_ALL (simpLib.SIMP_RULE std_ss [lem,lem1] E0);

(*---------------------------------------------------------------------------

        Termination of the nested t.c. We'll use NA to abbreviate 
        the instantiated auxiliary function, i.e.,

          `N_aux(measure($- 101))`.

        The proof goes as follows:

        Induct on the termination relation, then tidy up the goal 
        with [lem,lem1]. The goal is now to show

            x < NA (x+11)

        - Unroll the function at x+11. The first step requires manually 
          instantiating the rec. eqn. with `x+11`. Since this instance is
          constrained by the nested t.c.

            x+11 < NA (x+22),

          we need to prove the nested t.c. This is accomplished by applying
          the ind. hyp. After being proved, the nested t.c. is also added 
          to the hypotheses, since it is used again later. Now  
          conditional rewriting will unroll the function at `x+11`, 
          exposing the nested call of the function. The goal is now to 
          prove 

            x < NA(NA(x+22))

        - At this point we know, by simple arithmetic,

            x < NA (x+22) -11

        - A case split is now made, on 

            NA (x+22) -11 > 100

          The motivation for this is rather mysterious, but let's see
          how the proof proceeds. Consider the two branches coming from
          the case split.  

          A. NA (x+22) -11 > 100, so NA (x+22) > 100 so, by using
             the rec. eqn. at NA(x+22) the goal NA(NA(x+22)) reduces
             to NA(x+22) -10. Our task is now to show that 
             x < NA(x+22) -10, and this is easy.

          B. ~(NA (x+22) -11 > 100), so the i.h. may be used to 
             deliver NA(x+22) -11 < NA(NA(x+22) -11+11)
                                  = NA(NA(x+22))
             After that, the proof of this branch is also easy.

          That finishes the whole proof. Revisiting the mysterious 
          case-split, we see that part A does not use the i.h. since 
          it unrolls NA to the non-recursive case. Part B, in contrast,
          uses the induction hypothesis instead of unrolling NA to the 
          recursive case, which would lead to madness. We see that the
          case split was aimed at making the conclusion of the 
          inductive hypothesis work out, so part B is very smooth. 
   
 ---------------------------------------------------------------------------*)


(* Steps 9-10: discharge the remaining nested termination condition of    *)
(* N_defn_2, following the plan in the comment above. The result is a     *)
(* pair: N (presumably the recursion equations of N) and the induction    *)
(* theorem N_INDUCT.                                                      *)
val (N,N_INDUCT) = Defn.tprove
(N_defn_2,
 (* Well-founded induction on the measure 101 - x. *)
 GEN_TAC THEN measureInduct_on `(\m. 101 - m) x`
  (* Put the i.h. back into the goal so it is simplified with lem, lem1. *)
  THEN DISCH_TAC THEN Q.PAT_ASSUM`!y. Q y` MP_TAC THEN RW_TAC std_ss[lem,lem1]
  (* Unroll the function at x+11, using the rec. eqn. E1. *)
  THEN MP_TAC (Q.SPEC `x+11` E1) THEN RW_TAC std_ss [ARITH `x < x+11-10`]
  (* The nested t.c. at x+11, proved from the i.h. and kept as a hyp. *)
  THEN `x+11 < N_aux(measure(\x.101-x)) (x+11+11)` by PROVE_TAC[ARITH`x<x+11`]
  THEN RW_TAC std_ss [] 
  THEN WEAKEN_TAC is_imp (* g.c. *)
  (* The case split that makes the i.h. fit in branch B. *)
  THEN Cases_on `N_aux(measure (\x.101-x)) (x+11+11) -11 > 100` THENL
   (* Case A: the inner value exceeds 100, so unroll the outer call. *)
   [WEAKEN_TAC is_forall (* i.h. not used in this branch *)
      THEN `N_aux(measure(\x.101-x)) (x+11+11) > 100` by DECIDE_TAC
      THEN MP_TAC (Q.SPEC `N_aux(measure (\x.101-x)) (x+11+11)` E1)
      THEN RW_TAC std_ss [] THEN DECIDE_TAC,
    (* Case B: the i.h. and transitivity of < close the goal. *)
    PROVE_TAC [arithmeticTheory.LESS_TRANS, 
               ARITH `x+y < p ==> x < p-y`, 
               ARITH `x+y < p ==> ((p-y)+y = p)`]]);

 
(*---------------------------------------------------------------------------
      Note that the above development is slightly unsatisfactory, since
      the partial correctness theorem still has constraints remaining as
      hypotheses. The termination proof discharges these constraints,
      but the resulting theorems are thrown away instead of being used
      to remove those hypotheses.

      Now try some computations with N.
 ---------------------------------------------------------------------------*)

(* Sample evaluations of N on either side of 100. By Npartly_correct,     *)
(* arguments up to 100 are expected to give 91, and larger arguments n    *)
(* are expected to give n-10.                                            *)
EVAL (Term `N 0`);
EVAL (Term `N 10`);
EVAL (Term `N 11`);
EVAL (Term `N 12`);
EVAL (Term `N 40`);
EVAL (Term `N 89`);
EVAL (Term `N 90`);
EVAL (Term `N 99`);
EVAL (Term `N 100`);
EVAL (Term `N 101`);
EVAL (Term `N 102`);
EVAL (Term `N 127`);
