(*---------------------------------------------------------------------------
 * CONDITIONAL EXPRESSIONS AND THEIR NORMALIZATION (Boyer and Moore).
 *---------------------------------------------------------------------------*)

app load ["tflLib", "QLib"];
open tflLib RW arithTools;

(*---------------------------------------------------------------------------
 * Define the datatype of conditional expressions.
 *---------------------------------------------------------------------------*)
(* tyinfo: datatype information for the new type "cond", built with
   Datatype.primHol_datatype from the current fact base.  It has two
   constructors: A, an atom carrying an ind, and IF, which takes three
   conds (test, then-branch, else-branch). *)
val tyinfo = Datatype.primHol_datatype 
               (Datatype.theFactBase())
                 `cond = A of ind
                       | IF of cond => cond => cond`;

(* Record tyinfo with Datatype.add.  Presumably this makes cond known to
   later lookups through Datatype.theFactBase, such as the one in Rdefine
   below. *)
val _ = Datatype.add tyinfo;

(* condDefine s q: a recursive definition over cond, named s and with
   prefix fixity.  q is a quotation of the defining equations.  The
   recursion axiom of cond is taken from tyinfo. *)
fun condDefine s q =
  new_recursive_definition {name = s, fixity = Prefix,
                            rec_axiom = Facts.axiom_of tyinfo, def=Term q};

(*---------------------------------------------------------------------------
 * Now a measure function for termination, due to Robert Shostak. 
 *---------------------------------------------------------------------------*)
(* Meas (Shostak's measure) is 1 on atoms.  On an IF node it is
   Meas x + Meas x * Meas y + Meas x * Meas z, which equals
   Meas x * (1 + Meas y + Meas z). *)
val Meas_DEF = condDefine "Meas"
     `(Meas (A i) = 1) /\
      (Meas (IF x y z) = Meas x + (Meas x * Meas y) + (Meas x * Meas z))`;

(*---------------------------------------------------------------------------
 * The definition of a normalization function
 *---------------------------------------------------------------------------*)
(* norm: normalize a conditional so that every IF test is an atom.
   norm is defined by wellfounded recursion (Rfunction) with respect to
   measure Meas.  The third clause rotates a nested test:
     IF (IF u v w) y z  -->  IF u (IF v y z) (IF w y z)
   The recursive call in that clause is not on a subterm, so termination
   is not structural.  The termination conditions are proved in
   norm_terminates below. *)
val norm_def = Rfunction "norm"
  `measure Meas`
  `(norm (A i) = A i) /\
   (norm (IF (A x) y z) = IF (A x) (norm y) (norm z)) /\
   (norm (IF (IF u v w) y z) = norm (IF u (IF v y z) (IF w y z)))`;



(*---------------------------------------------------------------------------
 *  Required lemma for termination.  
 *---------------------------------------------------------------------------*)
(* Meas_POSITIVE: |- !c. 0 < Meas c.
   The proof is by structural induction on cond, using the induction
   theorem from tyinfo, then rewriting with Meas_DEF and finishing with
   ARITH_TAC.  The goal is stated with an explicit beta-redex
   (\v. 0 < Meas v) c, presumably so that it matches the predicate form
   of the induction theorem for MATCH_MP_TAC.  BETA_RULE removes the
   redex from the proved theorem. *)
val Meas_POSITIVE = BETA_RULE
 (Q.prove`!c. (\v. 0 < Meas v) c`
  (MATCH_MP_TAC (Facts.induction_of tyinfo) THEN BETA_TAC THEN 
   RW_TAC[Meas_DEF] THEN ARITH_TAC));


(* Arithmetic rewrite packages shared by the termination proofs below.
   distribs: multiplication distributes over addition, on either side.
   assocs:   the associativity laws for + and *, reversed with GSYM so
             that rewriting re-brackets in the opposite direction to
             ADD_ASSOC / MULT_ASSOC as stated.
   LESS_MULT2: a local alias for the arithmetic theorem of that name. *)
val distribs =
  [arithmeticTheory.LEFT_ADD_DISTRIB, arithmeticTheory.RIGHT_ADD_DISTRIB];

val assocs =
  map GSYM [arithmeticTheory.ADD_ASSOC, arithmeticTheory.MULT_ASSOC];

val LESS_MULT2 = arithmeticTheory.LESS_MULT2;


(* norm_terminates: proves and saves the termination conditions of norm.
   The proof runs in these steps:
   - Rewrite the conditions with Meas_DEF plus the association and
     distribution laws.
   - Split them with REPEAT CONJ_TAC and strip quantifiers with
     REPEAT GEN_TAC.
   - Finish each of the three resulting goals with ARITH_TAC.
   The third goal (presumably the nested-IF clause, the only one
   mentioning u) first receives the extra facts 0 < Meas u * Meas y and
   0 < Meas u * Meas z.  These are built from Meas_POSITIVE with
   LESS_MULT2.  Presumably ARITH_TAC cannot derive positivity of these
   non-linear products by itself.
   NOTE(review): val [Pu,Py,Pz] = ... is a non-exhaustive list pattern,
   which triggers a compiler warning.  It cannot fail here because the
   mapped list always has exactly three elements. *)
val norm_terminates = save_thm("norm_terminates", 
prove_termination norm_def
(RW_TAC(Meas_DEF::(assocs@distribs))
  THEN REPEAT CONJ_TAC THEN REPEAT GEN_TAC
  THENL
    [ ALL_TAC, ALL_TAC,
      let val [Pu,Py,Pz] = map (Lib.C Q.SPEC Meas_POSITIVE) [`u`,`y`,`z`]
      in MP_TAC(LIST_CONJ(map (MATCH_MP LESS_MULT2 o LIST_CONJ) 
                              [[Pu,Py],[Pu,Pz]])) end]
  THEN ARITH_TAC));


(* norm_eqns: the defining equations of norm, with the termination
   conditions eliminated by rewriting with norm_terminates.
   norm_induction: the induction theorem for norm.  DISCH_ALL first turns
   its assumptions into antecedents, presumably the termination
   conditions, so that rewriting with norm_terminates can remove them. *)
val norm_eqns = save_thm("norm_eqns",
  RW_RULE [norm_terminates] (#rules norm_def));

val norm_induction = save_thm("norm_induction",
  RW_RULE [norm_terminates] (DISCH_ALL (#induction norm_def)));


(*---------------------------------------------------------------------------
 * Define it again, using a lexicographic combination of relations. This is 
 * the version given in the Boyer-Moore book.
 *---------------------------------------------------------------------------*)
(* TDEPTH: how deeply IFs are nested in test position.  It is 0 for an
   atom, and one more than the TDEPTH of the test for an IF.  Only the
   test argument x is examined. *)
val TDEPTH_DEF = 
  condDefine "TDEPTH"
     `(TDEPTH (A i) = 0) /\
      (TDEPTH (IF x y z) = TDEPTH x + 1)`;

(* Weight: 1 for an atom, and Weight x * (Weight y + Weight z) for
   IF x y z.  The nested-test rotation used by N below is
     IF (IF u v w) y z  -->  IF u (IF v y z) (IF w y z)
   It leaves Weight unchanged, since both sides equal
   Weight u * (Weight v + Weight w) * (Weight y + Weight z).  It also
   lowers TDEPTH by one.  This is why the two measures are combined
   lexicographically. *)
val Weight_def = 
    condDefine "Weight"
     `(Weight (A i) = 1) /\
      (Weight (IF x y z) = Weight x * (Weight y + Weight z))`;

(* Weight_positive: |- !c. 0 < Weight c.
   The proof is by structural induction on cond, stated in the same
   beta-redex form as Meas_POSITIVE.  After rewriting with Weight_def and
   distribs, CONJ_TAC separates two goals:
   - The first (presumably the atom case) is closed by ARITH_CONV.
   - In the second, LESS_IMP_LESS_ADD reduces the sum to its first
     product.  LESS_MULT2 then reduces that product to positivity of its
     two factors.  RW_TAC [] closes the result from the induction
     hypotheses th. *)
val Weight_positive = BETA_RULE
 (Q.prove`!c. (\v. 0 < Weight v) c`
  (MATCH_MP_TAC (Facts.induction_of tyinfo) THEN BETA_TAC THEN 
   RW_TAC(Weight_def::distribs) THEN CONJ_TAC 
   THENL
   [CONV_TAC ARITH_CONV,
   REPEAT GEN_TAC THEN DISCH_THEN 
    (fn th => MATCH_MP_TAC(arithmeticTheory.LESS_IMP_LESS_ADD) THEN 
              MATCH_MP_TAC LESS_MULT2 THEN MP_TAC th)
   THEN RW_TAC[]]));


(* Infix ## (precedence 400): (f ## g) x = (f x, g x), which pairs the
   results of two functions applied to the same argument.  "ptp" is the
   name given to the definition.  Below, it builds the pair measure
   Weight##TDEPTH used for N. *)
val point_to_prod_def = 
  new_infix_definition
    ("ptp", Term`## (f:'a->'b) (g:'a->'c) x = (f x, g x)`, 400);

(*---------------------------------------------------------------------------*
 * Notice how the lexicographic combination of measure functions is formed.  *
 * Its wellfoundedness is proved automatically. It might be handy to have a  *
 * combinator for this case; something like                                  *
 *                                                                           *
 *    f1 XX f2 XX ... XX fn = inv_image ($< ** ... ** ($< ** $<))            *
 *                                      (f1 ## ... ## fn)                    *
 *---------------------------------------------------------------------------*)

(* Rdefine thml: a definition function built from tflLib's rfunction.
   Like Rfunction, which is used for norm above, it is applied to a name,
   a relation and the equations (see Ndef).  Its Tfl.postprocess
   configuration uses:
   - WFtac = WF_TAC [], to establish wellfoundedness of the relation;
   - terminator = tflLib's terminator;
   - simplifier = tc_simplifier thml.
   The extra theorems thml (e.g. point_to_prod_def) are thereby made
   available when the termination conditions are simplified.  Presumably
   this is what unfolds the ## pair measure in the conditions. *)
local open tflLib
in
fun Rdefine thml = 
rfunction (Tfl.postprocess{WFtac = WF_TAC[],
                       terminator = terminator, 
                       simplifier = tc_simplifier thml}
             (Datatype.theFactBase()))
          RW_RULE
end;


(* N: the Boyer-Moore version of norm.  It has the same three equations,
   but is defined over the relation inv_image ($< ** $<) (Weight##TDEPTH),
   i.e. lexicographic on Weight and then TDEPTH.  point_to_prod_def is
   passed so that the termination-condition simplifier can unfold ##.
   NOTE(review): the name argument here is "inv_image", whereas norm was
   defined with its own name "norm"; confirm that "inv_image" (rather
   than "N") is intended. *)
val Ndef = Rdefine[point_to_prod_def] "inv_image"
  `inv_image ($< ** $<) (Weight##TDEPTH)`
  `(N(A i) = A i) /\
   (N(IF(A x) y z) = IF(A x) (N y) (N z)) /\
   (N(IF(IF u v w) y z) = N(IF u (IF v y z) (IF w y z)))`;


(* Short names for the basic multiplication and addition clauses, used
   as rewrites in the termination proof for N. *)
val (MULTS, ADDS) =
  (arithmeticTheory.MULT_CLAUSES, arithmeticTheory.ADD_CLAUSES);

(* Nterminates: proves and saves the termination conditions of N.
   After tc_simplifier unfolds the relation, each condition presumably
   has the form
     Weight a < Weight b  \/  (Weight a = Weight b) /\ TDEPTH a < TDEPTH b
   The two calls in the IF (A x) y z clause (on y and on z) are closed by
   the first disjunct.  The weight of that node is 1 * (Weight y +
   Weight z), so the call on y needs 0 < Weight z and the call on z
   needs 0 < Weight y.  Both positivity facts are supplied to each of
   these two goals.  The proof therefore no longer depends on the order
   in which TFL emits the two conditions.
   The nested-IF condition is closed by the second disjunct.  Rotating
   IF (IF u v w) y z to IF u (IF v y z) (IF w y z) preserves Weight
   (checked by ARITH_CONV after distributing) and lowers TDEPTH by one. *)
val Nterminates = save_thm("Nterminates",
prove_termination Ndef
(let (* 0 < Weight y /\ 0 < Weight z, pushed as an antecedent. *)
     val both_positive =
       MP_TAC (CONJ (Q.SPEC`y` Weight_positive) (Q.SPEC`z` Weight_positive))
 in
 REPEAT CONJ_TAC
 THENL
 [ REPEAT GEN_TAC THEN DISJ1_TAC THEN RW_TAC[Weight_def,MULTS,ADDS]
    THEN both_positive THEN CONV_TAC ARITH_CONV,
   REPEAT GEN_TAC THEN DISJ1_TAC THEN RW_TAC[Weight_def,MULTS,ADDS]
    THEN both_positive THEN CONV_TAC ARITH_CONV,
   REPEAT GEN_TAC THEN DISJ2_TAC
    THEN RW_TAC([Weight_def,MULTS,ADDS]@distribs@assocs) THEN CONJ_TAC
    THENL[CONV_TAC ARITH_CONV, RW_TAC[TDEPTH_DEF]] THEN CONV_TAC ARITH_CONV]
 end));

(* Neqns / Ninduction: the equations and induction theorem of N, with
   the termination conditions removed using Nterminates, as was done for
   norm above.  DISCH_ALL first turns the induction theorem's assumptions
   into antecedents so that the rewrite can eliminate them. *)
val Neqns = save_thm("Neqns",  RW_RULE [Nterminates] (#rules Ndef));

val Ninduction = save_thm("Ninduction",
  RW_RULE [Nterminates] (DISCH_ALL (#induction Ndef)));
