(*---------------------------------------------------------------------------
           McCarthy's 91 function.
 ---------------------------------------------------------------------------*)

app load ["TotalDefn", "BasicProvers", "SingleStep", "numLib"];

open TotalDefn BasicProvers SingleStep numLib;
infix 8 by;

(*---------------------------------------------------------------------------
         Miscellaneous support
 ---------------------------------------------------------------------------*)

(* Definitions needed to unfold the "measure" termination relation. *)
val measure_def   = prim_recTheory.measure_def
and inv_image_def = relationTheory.inv_image_def;

(* Simpset: bool_ss extended with arithSimps.ARITH_ss. *)
val arith_ss = simpLib.++ (bool_ss, arithSimps.ARITH_ss);

(* Prove an arithmetic fact supplied as a term quotation. *)
val ARITH = ARITH_PROVE o Term;

(* For x <= 100, the measure 101 - _ decreases from x to y exactly when
   x < y. Used to turn measure side conditions into plain < goals. *)
val lem = ARITH `~(x > 100) ==> (101-y < 101-x = x<y)`;


(*---------------------------------------------------------------------------
       Define the function. Note that the definition of nested
       functions is achieved by the invisible definition of an
       "auxiliary" function from which the desired function is 
       obtained. The termination proof of the desired function
       is reduced to that of the auxiliary function.

       We use Hol_defn to make the definition, since we tackle
       the termination proof ourselves.
 ---------------------------------------------------------------------------*)

(* N_defn: McCarthy's 91 function, defined with Defn.Hol_defn (not Define)
   so that the termination proof can be supplied by hand later on.
   For x > 100 it returns x - 10; otherwise it recurses as N (N (x + 11)). *)
val N_defn = 
 Defn.Hol_defn "N" 
               `N(x) = if x>100 then x-10 else N (N (x+11))`;


(* Neqn: the recursion equation(s) of N from Defn.eqns_of. Before
   termination is proved these presumably still depend on unproved
   termination conditions -- TODO confirm against the Defn docs. *)
val Neqn = Defn.eqns_of N_defn;

(*---------------------------------------------------------------------------
      Prove partial correctness for N, to see how such a proof
      works when the termination relation has not yet been supplied.
 ---------------------------------------------------------------------------*)

(* Npartly_correct: the theorem assumes a well-founded R that satisfies
   N's two termination conditions. One is the nested call, stated via
   N_aux R (x + 11); the other is the inner call on x + 11. Under these
   assumptions, N n = n - 10 when n > 100 and N n = 91 otherwise.
   The induction theorem comes from Defn.ind_of. The `val SOME` pattern
   raises Bind if Defn.ind_of returns NONE. *)
local val SOME Nind = Defn.ind_of N_defn
in
val Npartly_correct = Q.prove(
 `WF R /\ 
  (!x. ~(x > 100) ==> R (N_aux R (x + 11)) x) /\
  (!x. ~(x > 100) ==> R (x + 11) x)
     ==>
  !n. N(n) = if n>100 then n-10 else 91`,
STRIP_TAC THEN recInduct Nind           (* induct using N's induction thm *)
  THEN REPEAT STRIP_TAC
  THEN ONCE_REWRITE_TAC [Neqn]           (* unfold N once *)
  THEN RW_TAC bool_ss []
  (* move the assumption "COND ... > _" back into the goal *)
  THEN Q.PAT_ASSUM `COND p q r > s` MP_TAC
  THEN ASM_REWRITE_TAC[] 
  (* first subgoal: also move two more assumptions into the goal so
     ARITH_TAC can see them *)
  THENL [NTAC 2 (POP_ASSUM MP_TAC), ALL_TAC]
  THEN ARITH_TAC)
end;


(*---------------------------------------------------------------------------
      Prove termination for N. The termination relation is 
      "measure \x. 101 - x". Termination is shown by proving
      the termination of N_aux.  We also make another definition, 
      Nine1, just for readability in the proofs.
 ---------------------------------------------------------------------------*)

(* Nine1 abbreviates the auxiliary function N_aux instantiated with the
   termination relation measure (\x. 101 - x). *)
val Nine1_def = Define `Nine1 = N_aux (measure \x. 101 - x)`;

(* The same definition with measure and inv_image unfolded and then
   beta-reduced. It is folded back (via GSYM) in the termination proof. *)
val Nine1_def' =
  (BETA_RULE o REWRITE_RULE [measure_def, inv_image_def]) Nine1_def;


(* eqn0: the recursion equations of the auxiliary function N_aux with R
   instantiated to measure (\x. 101 - x). The result is beta-reduced and
   then rewritten with prim_recTheory.WF_measure, presumably to discharge
   the WF hypothesis. The `val SOME` pattern raises Bind if
   Defn.aux_defn returns NONE. *)
val eqn0 =
  let
    val SOME N_aux_defn = Defn.aux_defn N_defn
    val aux_eqns = Defn.eqns_of N_aux_defn
    val at_measure = Q.INST [`R` |-> `measure \x. 101 - x`] aux_eqns
  in
    REWRITE_RULE [prim_recTheory.WF_measure] (BETA_RULE at_measure)
  end;


(* eqn1: the base case of Nine1. For x > 100, Nine1 x = x - 10.
   Proved by unfolding Nine1 and appealing to eqn0. *)
val eqn1 = Q.prove(`!x. x>100 ==> (Nine1 x = x-10)`,
RW_TAC bool_ss [Nine1_def] 
  THEN MP_TAC eqn0 
  THEN RW_TAC bool_ss []);


(* eqn2: the recursive case of Nine1. For x <= 100, and provided that
   x < Nine1 (x + 11), Nine1 unrolls to its nested call. The extra
   hypothesis supplies the side conditions of eqn0. These are presumably
   the measure-decrease conditions, which lem turns into plain <. *)
val eqn2 = Q.prove
(`!x. ~(x>100) /\ x < Nine1 (x+11) 
         ==> 
     (Nine1 x = Nine1 (Nine1 (x+11)))`,
 RW_TAC bool_ss [Nine1_def] 
   THEN MP_TAC eqn0 THEN RW_TAC bool_ss []
   (* apply the conditional recursion equation to the goal *)
   THEN FIRST_ASSUM (MATCH_MP_TAC o REWRITE_RULE [AND_IMP_INTRO])
   (* second side condition: bring the x < _ hypothesis into the goal *)
   THEN CONJ_TAC THENL [ALL_TAC, Q.PAT_ASSUM `x<y` MP_TAC]
   THEN RW_TAC arith_ss [lem,measure_def,inv_image_def]);


(*---------------------------------------------------------------------------

        Termination proof. It's worth noting that termination is
        completely independent of the partial correctness proof.
        [It used to be widely believed that termination and partial
         correctness of nested recursive functions had to be proved
         simultaneously: rather than unrolling the function definition
         at the nested application, one would appeal to the partial
         correctness property. The fear was that expanding the function
         definition before its termination had been proved, i.e., using
         the definition in its own termination proof, would be circular.
         This isn't so: the definition can validly be unrolled at
         "sufficiently small" instances. Often this suffices to complete
         the proof, and in such cases partial correctness is not needed.]

 ---------------------------------------------------------------------------*)

(* Discharge N's termination conditions with the relation
   measure (\x. 101 - x). Defn.tprove returns the pair (N_eqn, N_ind);
   N_eqn is the equation later used by the evaluator.
   The remaining nested condition is restated in terms of Nine1 and proved
   by strong induction on 101 - x:
   - if x + 11 > 100, eqn1 evaluates the inner call directly;
   - otherwise, eqn2 unrolls Nine1 (x + 11), and the induction hypothesis
     bounds the result. *)
val (N_eqn, N_ind) = Defn.tprove(N_defn,
WF_REL_TAC `measure \x. 101 - x`  (* non-nested TC proved here *)
  (* fold the unfolded measure back into Nine1; lem turns measure goals
     into plain < goals *)
  THEN RW_TAC bool_ss [GSYM Nine1_def', lem]
  THEN measureInduct_on `(\m. 101 - m) x`    (* strong induction on 101 - x *)
  THEN RW_TAC bool_ss [] 
  (* use lem (with the ~(x > 100) assumption) to rewrite the measure
     comparisons in the assumptions into plain < form *)
  THEN IMP_RES_THEN (fn th => RULE_ASSUM_TAC (REWRITE_RULE [th])) lem
  THEN Cases_on `x+11 > 100` THENL
   [(* inner argument is above 100: eqn1 gives Nine1 (x+11) directly *)
    RW_TAC arith_ss [eqn1],
    (* x + 11 <= 100: reason about the nested call Nine1((x+11)+11) *)
    `x+11 < Nine1((x+11)+11)`  by PROVE_TAC [ARITH `x<x+11`]            THEN
    `x < Nine1((x+11)+11) -11` by PROVE_TAC [ARITH `x+y < z = x < z-y`] THEN
    `Nine1 (x+11) = Nine1 (Nine1 ((x+11) + 11))` by PROVE_TAC[eqn2]     THEN
    (* first case: Nine1((x+11)+11) - 11 > 100, so eqn1 applies to the
       outer call; second case: use the induction hypothesis (RES_TAC) *)
    Cases_on `Nine1((x+11)+11) -11 > 100` THENL
     [`Nine1((x+11)+11) > 100` by PROVE_TAC [ARITH`x-11 > 100 ==> x>100`] THEN
      `Nine1(Nine1((x+11)+11)) = Nine1((x+11)+11) -10` by PROVE_TAC[eqn1] 
        THEN Q.PAT_ASSUM `Nine1 (x+11) = M` (SUBST_ALL_TAC o SYM) 
        THEN PROVE_TAC[ARITH`x>100 ==> x-11<x-10`,arithmeticTheory.LESS_TRANS],
      RES_TAC 
       (* simplify (y - 11) + 11 back to y, given w + 11 < y *)
       THEN IMP_RES_THEN SUBST_ALL_TAC (ARITH`w+11<y ==> ((y-11)+11 = y)`)
       THEN PROVE_TAC [eqn2,arithmeticTheory.LESS_TRANS]]]);
 


(*---------------------------------------------------------------------------
      Note that the development above is somewhat awkward: the partial
      correctness theorem (Npartly_correct) still carries its
      termination-condition hypotheses. Those conditions were in fact
      established during the termination proof, but the resulting
      theorems were discarded rather than used to discharge them.

      Now try some computations with N.
 ---------------------------------------------------------------------------*)

(* Eval: call-by-value evaluator for N. Its compset starts from
   reduceLib.reduce_rws (), then gets the definitions of >= and > from
   arithmeticTheory, and finally N's recursion equation N_eqn. Each call is
   wrapped in Count.apply, presumably to report the cost of the evaluation. *)
local
  open computeLib arithmeticTheory
  val rws = reduceLib.reduce_rws ()
  val () = add_thms [GREATER_EQ, GREATER_DEF] rws
  val () = add_thms [N_eqn] rws
in
  val Eval = Count.apply (CBV_CONV rws) o Term
end;

(* Sample evaluations. Npartly_correct states that N n = 91 for n <= 100
   and N n = n - 10 for n > 100. So every call below up to N 100 should
   give 91, N 101 = 91, N 102 = 92, and N 127 = 117. *)
Eval `N 0`;
Eval `N 10`;
Eval `N 11`;
Eval `N 12`;
Eval `N 40`;
Eval `N 89`;
Eval `N 90`;
Eval `N 99`;
Eval `N 100`;
Eval `N 101`;
Eval `N 102`;
Eval `N 127`;
