(*---------------------------------------------------------------------------

    An example taken from

        "Automating induction over mutually recursive functions",
        Deepak Kapur and M. Subramaniam, Proceedings of AMAST'96,
        Springer LNCS 1101.

    The example displays the equivalence of call-by-name and 
    call-by-value evaluation strategies for a type of simple arithmetic 
    terms. The evaluation functions feature mutual and nested recursions.

 ---------------------------------------------------------------------------*)

load "bossLib"; open bossLib; infix 8 by;


(* Abstract syntax of the arithmetic terms, polymorphic in the type 'a of
   variable names.  Both evaluators below treat "App B v M" as the body B
   evaluated with the variable v bound to the argument M. *)
Hol_datatype `arith = Con of num                   (* constant *)
                    | Var of 'a                    (* variable *)
                    | Add of arith => arith        (* addition *)
                    | App of arith => 'a => arith  (* func. application *)`;


(*---------------------------------------------------------------------------
                     Call-by-name evaluation
 ---------------------------------------------------------------------------*)

(* Two mutually recursive functions are introduced by this single Hol_defn:

     CBN (e,y,N)  rewrites e, replacing each occurrence of the variable y
                  by CBNh N and eliminating App nodes on the way;
     CBNh e       eliminates the App nodes of e without any substitution.

   The App clause of CBN is a nested recursion: the result of the inner
   call CBN(B,v,M) is the first argument of the outer call.  Hol_defn only
   records the definition; its termination conditions are dealt with
   further down in this file. *)
val CBN_defn = 
 Hol_defn 
  "CBN"
    `(CBN (Con n, y, N)     = Con n)
 /\  (CBN (Var x, y, N)     = if x=y then CBNh N else Var x)
 /\  (CBN (Add a1 a2, y, N) = Add (CBN(a1,y,N)) (CBN(a2,y,N)))
 /\  (CBN (App B v M, y, N) = CBN (CBN(B,v,M), y, N))

 /\  (CBNh (Con n) = Con n)
 /\  (CBNh (Var x) = Var x)
 /\  (CBNh (Add a1 a2) = Add (CBNh a1) (CBNh a2))
 /\  (CBNh (App B v M) = CBN (B,v,M))`;


(* Recursion equations and induction theorem extracted from the CBN
   definition.  The SOME pattern makes the script stop here if no induction
   theorem is available.  Both results are later passed through PROVE_HYP,
   i.e. they are expected to carry the termination conditions as
   hypotheses. *)
val CBN_eqns     = Defn.eqns_of CBN_defn;
val SOME CBN_ind = Defn.ind_of CBN_defn;


(*---------------------------------------------------------------------------
                     Environment lookup
 ---------------------------------------------------------------------------*)

(* lookup x env returns the value paired with the first occurrence of x in
   the association list env, and 0 when x does not occur in env. *)
val lookup_def = 
 Define
     `(lookup x [] = 0)
  /\  (lookup x ((y,N)::rst) = if x=y then N else lookup x rst)`;


(*---------------------------------------------------------------------------
                     Call-by-value evaluation
 ---------------------------------------------------------------------------*)

(* CBV (e,env) evaluates e to a number in the environment env, an
   association list from variables to numbers.  For "App B v M" the argument
   M is evaluated first (the nested recursive call) and its value is bound
   to v at the front of the environment used for the body B. *)
val CBV_defn = Hol_defn "CBV"
    `(CBV (Con n, env)  = n)
 /\  (CBV (Var x, env)  = lookup x env)
 /\  (CBV (Add a1 a2, env) = CBV(a1,env) + CBV(a2,env))
 /\  (CBV (App B v M, env) = CBV(B, (v,CBV(M,env))::env))`;


(* Recursion equations and induction theorem extracted from the CBV
   definition; the SOME pattern fails if no induction theorem exists. *)
val CBV_eqns      = Defn.eqns_of CBV_defn;
val SOME CBV_ind  = Defn.ind_of CBV_defn;


(*---------------------------------------------------------------------------
          Some munging to get tidier defns and induction thms.
 ---------------------------------------------------------------------------*)

(* The termination relations of the two definitions, as HOL terms.  They
   are used below as the formal parameters of CBNTerminates and
   CBVTerminates.  The SOME patterns fail if a definition has no relation. *)
val SOME R  = Defn.reln_of CBN_defn;
val SOME R1 = Defn.reln_of CBV_defn;

(* CBNTerminates R holds exactly when R satisfies every termination
   condition of the CBN definition (the conditions are conjoined). *)
val CBNTerminates_def =
  let
    val cbn_tcs = list_mk_conj (Defn.tcs_of CBN_defn)
  in
    Define `CBNTerminates ^R = ^cbn_tcs`
  end;

(* CBVTerminates R1 holds exactly when R1 satisfies every termination
   condition of the CBV definition (the conditions are conjoined). *)
val CBVTerminates_def =
  let
    val cbv_tcs = list_mk_conj (Defn.tcs_of CBV_defn)
  in
    Define `CBVTerminates ^R1 = ^cbv_tcs`
  end;

(* The individual CBN termination conditions, each derived under the single
   assumption that the CBNTerminates predicate holds of the relation. *)
val TC0thms =
  let
    val unfolded = SPEC_ALL CBNTerminates_def
    val assumed  = ASSUME (lhs (concl unfolded))
  in
    CONJUNCTS (EQ_MP unfolded assumed)
  end;

(* The individual CBV termination conditions, each derived under the single
   assumption that the CBVTerminates predicate holds of the relation. *)
val TC1thms =
  let
    val unfolded = SPEC_ALL CBVTerminates_def
    val assumed  = ASSUME (lhs (concl unfolded))
  in
    CONJUNCTS (EQ_MP unfolded assumed)
  end;

(* Replace the separate termination-condition hypotheses of the equations
   and of the CBN induction theorem by the hypotheses of the theorems in
   TC0thms / TC1thms (the single "Terminates" assumption). *)
local
  fun discharge_tcs tc_thms th = itlist PROVE_HYP tc_thms th
in
  val CBN_eqns' = discharge_tcs TC0thms CBN_eqns
  val CBN_ind'  = discharge_tcs TC0thms CBN_ind
  val CBV_eqns' = discharge_tcs TC1thms CBV_eqns
end;


(*---------------------------------------------------------------------------
    Partial correctness is phrased as the following:

       CBNTerminates R /\ CBVTerminates R1 
           ==>
       !x y N env. 
           CBV (CBN(x,y,N),env) 
              = 
           CBV (x, (y,CBV(N,env))::env)

    We build an induction theorem first, by instantiating the 
    induction theorem for CBN to the predicates suggested by the 
    method of Richard Boulton.
 ---------------------------------------------------------------------------*)

(* Induction predicate for CBN: a paired abstraction over (x,y,N) whose body
   is the correctness statement, quantified over the environment. *)
val tm = Term
 `\(x,y,N). 
    !env. CBV (CBN (x,y,N), env) 
               = 
          CBV(x, (y, CBV(N,env))::env)`;

(* Induction predicate for CBNh: CBNh does not change the CBV value of a
   term, in any environment. *)
val tm1 = Term `\N. !env. CBV (CBNh N, env) = CBV(N,env)`;

(* The two universally quantified predicate variables of the induction
   theorem; the pattern fails unless there are exactly two. *)
val [P0,P1] = fst(strip_forall(concl CBN_ind'));
val ind0 = SPEC_ALL CBN_ind';
(* Instantiate the predicates with tm and tm1, then beta-reduce everywhere
   (GEN_BETA_CONV also handles the paired abstraction in tm). *)
val ind1 = CONV_RULE (DEPTH_CONV Let_conv.GEN_BETA_CONV)
              (Rsyntax.INST [P0 |-> tm, P1 |-> tm1] ind0);
(* Move the induction premises into the hypotheses and keep only the first
   of the two conclusions (the one that is specialised to a triple next). *)
val [ind2a, _] = CONJUNCTS (UNDISCH ind1);
(* Specialise to an explicit triple and simplify with the pair rewrites. *)
val ind3 = REWRITE_RULE pairTheory.pair_rws
             (SPEC (Term`(x,y,N):'a arith#'a#'a arith`) ind2a);
(* Put the induction premises back as the antecedent of an implication, so
   that the result can be used with MATCH_MP_TAC. *)
val ind4 = DISCH (fst(dest_imp(concl ind1)))
                 (SPEC_ALL ind3);


(*---------------------------------------------------------------------------
      We distinguish the names "R" and "R1" of the termination 
      relations; otherwise, the termination relations for both 
      CBN and CBV would have the same name, which is confusing.
 ---------------------------------------------------------------------------*)

(* A variable with the type of the CBV termination relation, but with the
   distinguishing name "R1". *)
val R1' = mk_var ("R1", type_of R1);

(* The CBV equations with the relation variable renamed to R1'.  The
   hypotheses are discharged first so that the instantiation reaches them
   as well, and the resulting antecedent is then undischarged again. *)
val CBV_eqns'' =
  let
    val discharged = DISCH_ALL CBV_eqns'
    val renamed    = Rsyntax.INST [R1 |-> R1'] discharged
  in
    UNDISCH renamed
  end;


(*---------------------------------------------------------------------------
           Given ind4, the correctness proof is easy!
 ---------------------------------------------------------------------------*)

(* Partial correctness, stored under the name "KapurSubra": for any
   relations satisfying the termination conditions of CBN and CBV,
   evaluating CBN(x,y,N) by value gives the same number as evaluating x by
   value with y bound to the value of N. *)
val KapurSubra = store_thm("KapurSubra",
 Term`!R R1.
        CBNTerminates ^R /\ CBVTerminates ^R1'
            ==>
        !x (y:'a) N env. 
                CBV (CBN(x,y,N),env) 
                   = 
                CBV (x, (y, CBV(N,env))::env)`,
(* Strip the outer quantifiers and assume the two termination predicates. *)
REPEAT GEN_TAC THEN STRIP_TAC
 (* Backchain through the instantiated induction theorem ind4. *)
 THEN MATCH_MP_TAC ind4
 (* Finish by rewriting with the conditional recursion equations. *)
 THEN RW_TAC std_ss [CBN_eqns',CBV_eqns'',lookup_def]);


(*---------------------------------------------------------------------------
       Remaining task: supply R and R1 such that 

          CBNTerminates R /\ CBVTerminates R1

       This will require reasoning about the auxiliary functions
       used in defining CBN and CBV. Since CBN is mutually recursive
       with CBNh and also a nested recursion, this makes things 
       not completely obvious. We tackle CBV first, since it's
       easier.
 ---------------------------------------------------------------------------*)

(* Termination of CBV: take the first of the relations suggested by
   TotalDefn.guessR (the pattern requires exactly two suggestions) and show
   that it satisfies CBVTerminates, by unfolding that definition and using
   the well-foundedness of measures. *)
local val [guess1,_] = TotalDefn.guessR CBV_defn
in
val CBVTerminates = Q.prove
(`CBVTerminates ^guess1`,
 TotalDefn.TC_SIMP_TAC [] [prim_recTheory.WF_measure,CBVTerminates_def])
end;


(*---------------------------------------------------------------------------
    Termination of CBN is much harder! First there's the problem of 
    seeing why it terminates, then there's the task of expressing
    termination formally, and finally, the proof of termination is 
    also a little bit challenging.

    1. Why it terminates. CBN eliminates all "App" nodes in an
       expression. Thus, we can simply count the App nodes.
       However, the case where there are no App nodes in the 
       expression must also be dealt with; in this case, the 
       recursion is always on smaller arguments.

    2. Formal expression. The specification of CBN leads to several 
       background definitions: a "union" function is constructed to 
       handle the mutual recursion, and an "auxiliary" function is
       built to handle the nested recursion. The termination constraints
       that we have to eliminate are those of the union function, and
       these contain occurrences of the auxiliary function.

       The way the analysis of (1) is formalized starts by defining 
       a count of the App nodes in an expression. This measure will 
       be put together in a lexicographic combination with the normal
       size functions. We just have to fiddle to make sure that 
       different size functions are applied in the different injections
       into the sum.
      
 ---------------------------------------------------------------------------*)

(* AppCount e is the number of App nodes in e: an App node contributes 1
   plus the counts of its body and argument, Add sums the counts of its
   operands, and every other term (last, wildcard clause) counts 0. *)
val AppCount_def =
  Define
    `(AppCount (App B _ M) = 1 + AppCount B + AppCount M)
 /\  (AppCount (Add M N)   = AppCount M + AppCount N)
 /\  (AppCount   _____     = 0)`;

(* App count on the sum type: for a left injection (a triple) the counts of
   the first and third components are added and the middle component is
   ignored; for a right injection the count of the single term is used. *)
val JointAppCount_def =
  Define
    `(JointAppCount (INL(P,_,Q)) = AppCount P + AppCount Q)
 /\  (JointAppCount (INR P)      = AppCount P)`;

(* Size measure on the sum type, built from arith_size with every variable
   given size 0 (the argument \v.0): the sum of the sizes of the first and
   third components of a left injection, or the size of the single term of
   a right injection. *)
val JointStdSize_def =
  Define 
   `(JointStdSize (INL(P,_,Q)) = arith_size (\v.0) P + arith_size (\v.0) Q)
 /\ (JointStdSize (INR P)      = arith_size (\v.0) P)`;


(* The concrete termination relation for CBN: arguments are compared
   lexicographically, first by JointAppCount and then by JointStdSize.
   NOTE: this rebinds the ML identifier R, which until this point denoted
   the termination relation extracted from CBN_defn by Defn.reln_of. *)
val R = Term 
`inv_image ($< LEX $<)
  (\x:'a arith#'a#'a arith + 'a arith. (JointAppCount x, JointStdSize x))`;

(* Fn abbreviates CBN_UNION_aux applied to the concrete relation R.
   NOTE(review): CBN_UNION_aux is presumably the auxiliary constant that
   Hol_defn introduced for the union of CBN and CBNh -- confirm. *)
val Fn_def = 
 Define 
  `Fn = CBN_UNION_aux ^R`;

(* Un:   the union definition behind the mutual recursion of CBN_defn.
   Aux:  the auxiliary definition of Un (for the nested recursion).
   AuxR: Aux with its termination relation set to the concrete relation R.
   The SOME patterns fail if CBN_defn has no union or auxiliary part. *)
val SOME Un = Defn.union_defn CBN_defn;
val SOME Aux = Defn.aux_defn Un;
val AuxR = Defn.set_reln Aux R;

(* The conjunction of the termination conditions of AuxR, proved by
   simplification with the measure definitions and arith_size_def. *)
val basic_tcs = prove(list_mk_conj (Defn.tcs_of AuxR),
 TotalDefn.TC_SIMP_TAC [] 
    [JointAppCount_def,JointStdSize_def,
     AppCount_def,definition "arith_size_def"]);

(* Eliminate the proved termination conditions from AuxR, then rewrite its
   equations and induction theorem so that the auxiliary function applied
   to R appears under its short name Fn (SYM Fn_def folds the definition).
   valOf raises an exception if AuxR' has no induction theorem. *)
val AuxR' = Defn.elim_tcs AuxR (CONJUNCTS basic_tcs);
val eqns  = REWRITE_RULE [basic_tcs,SYM Fn_def] (Defn.eqns_of AuxR');
val ind   = REWRITE_RULE [basic_tcs,SYM Fn_def] (valOf(Defn.ind_of AuxR'));
(* eq4 is the fourth equation; the pattern fails if there are fewer than
   four conjuncts. *)
val (_::_::_::eq4::_) = CONJUNCTS eqns;

(* Key lemma for the nested recursion: the result of the auxiliary function
   Fn never contains an App node.  The proof is by the induction principle
   of the auxiliary function followed by rewriting with its equations and
   AppCount_def. *)
val AppCount_lem = Q.prove
(`!e. AppCount (Fn e) = 0`,
 recInduct ind
  THEN RW_TAC arith_ss [eqns,AppCount_def]
  (* Introduce as a subgoal the fact that the argument of the outer nested
     call is below the original App argument in the termination relation. *)
  THEN Q.SUBGOAL_THEN `inv_image ($< LEX $<) 
        (\x. (JointAppCount x,JointStdSize x))
        (INL (Fn (INL (B,v,M)),y,N)) (INL (App B v M,y,N))` ASSUME_TAC
  THENL
    [(* The subgoal itself: drop the implicational assumption and simplify
        with the measure definitions. *)
     Q.PAT_ASSUM `x ==> y` (K ALL_TAC)
       THEN TotalDefn.TC_SIMP_TAC []
               [JointAppCount_def, JointStdSize_def,AppCount_def],
     (* The original goal, now with the subgoal as an assumption: first-order
        reasoning with the fourth equation. *)
     PROVE_TAC [eq4]]);


(* Termination proof for CBN_defn: supply the concrete relation R as the
   witness, dispose of the conditions already established in basic_tcs, and
   simplify the rest with the measure definitions and AppCount_lem.
   The result rebinds CBN_eqns and CBN_ind, shadowing the earlier versions
   obtained directly from Defn.eqns_of / Defn.ind_of. *)
val (CBN_eqns, CBN_ind) = 
Defn.tprove(CBN_defn,
  EXISTS_TAC R THEN REWRITE_TAC [SYM Fn_def, basic_tcs]
   THEN TotalDefn.TC_SIMP_TAC [] 
          [JointAppCount_def,JointStdSize_def, AppCount_def,AppCount_lem]);


