(*----------------------------------------------------------------------------*
 *                                                                            *
 *       The Dutch National Flag - Term Rewriting Emulation                   *
 *                                                                            *
 *                                                                            *
 *         ML version                                                         *
 *      ----------------                                                      *
 *                                                                            *
 *   datatype colour = R | W | B;                                             *
 *                                                                            *
 *   val cons = curry (op ::);                                                *
 *   infix 3 ##;                                                              *
 *   fun (f##g) (x,y) = (f x, g y);                                           *
 *                                                                            *
 *   fun dnf []          = ([],false)                                         *
 *     | dnf (W::R::rst) = (R::W::rst, true)                                  *
 *     | dnf (B::R::rst) = (R::B::rst, true)                                  *
 *     | dnf (B::W::rst) = (W::B::rst, true)                                  *
 *     | dnf (x::rst)    = (cons x##I)(dnf rst);                              *
 *                                                                            *
 *  fun flag L = let val (alist,changed) = dnf L                              *
 *               in if changed then flag alist else alist                     *
 *               end;                                                         *
 *                                                                            *
 *  flag [R,W,W,B,R,W,W,R,B,B];                                               *
 *---------------------------------------------------------------------------*)
(*
   Invoke 

      hol98 -I <holdir>/src/tfl/examples/kls_list

*)

(*---------------------------------------------------------------------------*
 * Load and open useful libraries and theories.                              *
 *---------------------------------------------------------------------------*)

app load ["bossLib", "tflLib", "permTheory", "QLib"]; 
open bossLib tflLib; 
infix &&; infix 8 by;
open combinTheory listTheory kls_listTheory permTheory;

(*---------------------------------------------------------------------------*
 * Set up an inference counter.                                              *
 *---------------------------------------------------------------------------*)

(* Inference meter: created here and read/reported at the end of the file. *)
val meter = Count.mk_meter ();


(*---------------------------------------------------------------------------*
 * Define Red, White, and Blue to be the colours.                            *
 *---------------------------------------------------------------------------*)

(* The three colours of the flag: R(ed), W(hite), B(lue). *)
val _ = Hol_datatype `colour = R | W | B`;


(*---------------------------------------------------------------------------*
 * Define a standard combinator for functions and pairs.                     *
 *---------------------------------------------------------------------------*)
(* fpair_def:  (## f g) (x,y) = (f x, g y), i.e. apply f to the first and  *)
(* g to the second component of a pair.                                     *)
(* Order matters here: the definition is written with ## in prefix          *)
(* position, and only afterwards is ## declared infix (precedence 400), so  *)
(* that later terms can write (f##g).                                       *)
val fpair_def =
 Define 
   `## (f:'a->'b) (g:'c->'d) (x,y) = (f x, g y)`;
 
  (* ##_def is not a legal ML identifier; presumably this makes the theorem *)
  (* be exported under the ML name fpair_def instead -- TODO confirm.       *)
  set_MLname "##_def" "fpair_def"; 
  set_fixity "##" (Infix 400);


(*---------------------------------------------------------------------------*
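 * Define Dnf. It scans the list left to right for the first out-of-order    *
 * adjacent pair (W R, B R or B W), swaps it, and returns T. If there is     *
 * no such pair, the list is returned unchanged paired with F.               *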
 *---------------------------------------------------------------------------*)
(* Dnf scans left to right for the first adjacent pair that is out of      *)
(* order (W R, B R or B W). It swaps only that pair, returns T and does not *)
(* look at the rest of the list. The last clause keeps x and recurses on    *)
(* the tail, passing the flag through unchanged via (CONS x ## I). If there  *)
(* is no such pair, the input list comes back paired with F. The clauses    *)
(* overlap; as in the ML version in the file header, they are presumably    *)
(* tried top to bottom.                                                     *)
val Dnf_def_ind = 
Define 
 `(Dnf [] = ([], F))                                     /\ 
  (Dnf (CONS W (CONS R rst)) = (CONS R (CONS W rst), T)) /\ 
  (Dnf (CONS B (CONS R rst)) = (CONS R (CONS B rst), T)) /\ 
  (Dnf (CONS B (CONS W rst)) = (CONS W (CONS B rst), T)) /\ 
  (Dnf (CONS x rst)          = (CONS x##I) (Dnf rst))`;

(* Split the result of Define into the recursion equations (Dnf_def) and   *)
(* the induction theorem (Dnf_ind). This assumes this version of Define    *)
(* returns the two as a conjunction, as the rest of the script relies on.  *)
val Dnf_def = CONJUNCT1 Dnf_def_ind
and Dnf_ind = CONJUNCT2 Dnf_def_ind;


(*---------------------------------------------------------------------------*
 * Define the flag function.                                                 *
 *---------------------------------------------------------------------------*)

(* flag: apply Dnf repeatedly until it reports that no swap was made, then *)
(* return the list. The recursion is not structural. The termination       *)
(* relation R' and the termination condition stay as hypotheses of         *)
(* flag_def, and are dealt with below in flag_def1 and flag_terminates.    *)
val flag_def = 
Define 
 `flag l0 = let (l1, changed) = Dnf l0
              in 
                changed => flag l1 | l1`;

(*---------------------------------------------------------------------------*
 * Get nice-looking equations and induction thm for Dnf.                     *
 *---------------------------------------------------------------------------*)
(* Dnf_eqns (saved in the theory as Dnf): the ten recursion equations of   *)
(* Dnf, with two of them (e and h) simplified using the [] equation,       *)
(* fpair_def and I_THM, and moved to the front. Presumably e and h are the *)
(* one-element cases Dnf [B] and Dnf [W], whose right-hand sides still     *)
(* mention Dnf [] (compare P [B] and P [W] in Dnf_induction) -- TODO       *)
(* confirm. The list pattern assumes Define yields exactly ten conjuncts   *)
(* in this order; otherwise the binding raises Bind.                       *)
val Dnf_eqns = save_thm("Dnf",
   let val [dnf_nil,a,b,c,d,e,f,g,h,i] = CONJUNCTS Dnf_def
       val simpl = REWRITE_RULE [dnf_nil,fpair_def,I_THM]
   in 
      LIST_CONJ [dnf_nil, simpl e, simpl h, a,b,c,d,f,g,i]
   end);

(* Dnf_induction: Dnf_ind with its antecedent (the hypothesis under !P)    *)
(* rewritten into the explicit conjunction of cases below:                 *)
(*   - the empty list and the singletons [B] and [W];                      *)
(*   - the three swap patterns;                                            *)
(*   - four step cases.                                                    *)
(* TAUT_CONV must prove that the original antecedent equals this           *)
(* conjunction as a propositional tautology. The conjuncts are therefore   *)
(* presumably only required to match Define's up to reordering.            *)
val Dnf_induction = save_thm("Dnf_induction",
let val taut = Q.TAUT_CONV
        `^(#ant(dest_imp(#Body(dest_forall((concl Dnf_ind)))))) = 
         P [] /\
         P [B] /\
         P [W] /\
         (!rst. P (CONS W (CONS R rst))) /\
         (!rst. P (CONS B (CONS R rst))) /\
         (!rst. P (CONS B (CONS W rst))) /\
         (!rst. P (CONS B rst) ==> P (CONS B (CONS B rst))) /\
         (!rst. P (CONS B rst) ==> P (CONS W (CONS B rst))) /\
         (!rst. P (CONS W rst) ==> P (CONS W (CONS W rst))) /\
         (!rst. P rst ==> P (CONS R rst))`
 in 
    REWRITE_RULE[taut] Dnf_ind
 end);


(*---------------------------------------------------------------------------*
 * Termination measure function (suggestion of jrh). Earlier I was counting  *
 * the total number of swaps that take place. This way is much simpler.      *
 *---------------------------------------------------------------------------*)

(* Weight reads a list as a numeral with digits R=3, W=2, B=1. The head is *)
(* the least significant position and each later position counts twice as *)
(* much. Each swap Dnf can make lowers the weight. For a tail rst of       *)
(* weight w:                                                               *)
(*   W::R::rst  8+4w   becomes  R::W::rst  7+4w                            *)
(*   B::R::rst  7+4w   becomes  R::B::rst  5+4w                            *)
(*   B::W::rst  5+4w   becomes  W::B::rst  4+4w                            *)
(* A swap further along is scaled by a power of 2 and still decreases the  *)
(* weight; flag_terminates establishes this formally.                      *)
(* CONJUNCT1 keeps the equations and drops the second conjunct returned by *)
(* Define -- presumably the induction theorem, as for Dnf_def_ind above.   *)
val Weight_def = CONJUNCT1
  (Define
     `(Weight (CONS R rst) = 3 + 2*Weight rst) /\
      (Weight (CONS W rst) = 2 + 2*Weight rst) /\
      (Weight (CONS B rst) = 1 + 2*Weight rst) /\
      (Weight []           = 0)`);

(*---------------------------------------------------------------------------*
 * Miscellaneous lemmas and a simpset used in later proofs.                  *
 *---------------------------------------------------------------------------*)

(* func_to_prod_lem: if (x,y) = (f##g) z then z is some pair (p0,p1) with  *)
(* x = f p0 and y = g p1. Later proofs use it (via IMP_RES_THEN) to take   *)
(* apart the pair built by (CONS x ## I) in the recursive Dnf clause.      *)
(* Proof: case-split z into a pair, then rewrite with fpair_def.           *)
val func_to_prod_lem = Q.prove
`!(x:'a) (y:'b) (z:'a#'b). 
   ((x,y) = (f##g) z) ==> ?p0 p1. ((p0,p1) = z) /\ (x = f p0) /\ (y = g p1)`
(Cases_on `z` THEN RW_TAC bool_ss [fpair_def]);

(* prod_fg_var: (f##g) applied to an arbitrary x, expressed with FST and  *)
(* SND. Unlike fpair_def, whose left-hand side needs a syntactic pair     *)
(* (x,y), this form rewrites (f##g) applied to any term.                  *)
val prod_fg_var = Q.prove
`!(f:'a->'b) (g:'c->'d) x. (f##g) x = (f(FST x), g(SND x))`
(Cases_on `x` THEN RW_TAC bool_ss [fpair_def]);

(* lem: a pair M whose first component is x and whose second component is *)
(* false is equal to (x,F). Proved by case-splitting M into a pair.        *)
(* final_step_correct uses it to rebuild a no-swap equation for the tail.  *)
val lem = Q.prove 
 `!x M. (x=FST M) /\ ~SND M ==> ((x,F) = M)`
(GEN_TAC THEN Cases THEN RW_TAC bool_ss []);

(* filter_lem: the head of a non-empty result of filter P satisfies P.     *)
(* Proof by induction on l, unfolding filter_def.                          *)
val filter_lem = Q.prove
`!P l h t. (filter P l = CONS h t) ==> P h`
(Induct_on `l`
  THEN RW_TAC list_ss [filter_def]
  THEN PROVE_TAC[]);

(* append_filters_lem: the head of a non-empty concatenation of filter P l1 *)
(* and filter Q l2 satisfies P or Q. Proof: split on whether filter P l1 is *)
(* empty, then apply filter_lem to whichever part supplies the head.        *)
val append_filters_lem = Q.prove
`!P Q l1 l2 h t. 
   (APPEND (filter P l1) (filter Q l2) = CONS h t) ==> P h \/ Q h`
(REPEAT GEN_TAC 
  THEN Cases_on `filter P l1` 
  THEN RW_TAC list_ss [filter_def]
  THEN PROVE_TAC [filter_lem]);

(* Simpset for the proofs below: bossLib's list simpset, extended with the *)
(* rewrites generated for the colour datatype and with I_THM.              *)
(* && is left-associative (declared with plain infix above).               *)
val flag_ss =
  (bossLib.list_ss && type_rws "colour")
    && [I_THM];


(*---------------------------------------------------------------------------*
 * Instantiate the termination relation R' left open by the definition of    *
 * `flag' with measure Weight, then discharge its well-foundedness.          *
 *---------------------------------------------------------------------------*)
(* flag_def1: flag_def with its termination relation R' instantiated to   *)
(* measure Weight. The steps are:                                         *)
(*   1. DISCH_ALL moves the hypotheses into the conclusion, because       *)
(*      REWRITE_RULE only rewrites the conclusion.                        *)
(*   2. WF_measure then rewrites the well-foundedness antecedent away.    *)
(*   3. UNDISCH turns the next antecedent back into a hypothesis;         *)
(*      flag_terminates proves that hypothesis.                           *)
val flag_def1 =
  let
    val discharged   = DISCH_ALL flag_def
    val instantiated = Q.INST [`R'` |-> `measure Weight`] discharged
    val simplified   = REWRITE_RULE [WFTheory.WF_measure] instantiated
  in
    UNDISCH simplified
  end;

(*---------------------------------------------------------------------------*
 * Termination proof for flag.                                               *
 *---------------------------------------------------------------------------*)

(* flag_terminates: proves the first hypothesis of flag_def1, the         *)
(* remaining termination condition for flag. Presumably it states that    *)
(* Weight l1 < Weight l0 whenever (l1,T) = Dnf l0 -- TODO confirm shape.  *)
val flag_terminates = Q.prove
`^(hd(#1(dest_thm(flag_def1))))`
(RW_TAC flag_ss [WFTheory.measure_def,primWFTheory.inv_image_def]
  (* Substitute T for the last assumption (presumably the changed flag),  *)
  (* move the Dnf equation into the goal and generalise l0 and l1 so that *)
  (* the recursion induction for Dnf applies.                             *)
  THEN POP_ASSUM (SUBST_ALL_TAC o EQT_INTRO) THEN POP_ASSUM MP_TAC
  THEN Q.ID_SPEC_TAC`l1` THEN Q.ID_SPEC_TAC`l0`
  THEN INDUCT_THEN Dnf_induction MP_TAC
  THEN RW_TAC flag_ss [Dnf_eqns,Weight_def]  (* 4 remaining subgoals *)
  (* Split the pair produced by (CONS x ## I), use the induction           *)
  (* hypotheses, and finish with arithmetic on Weight.                    *)
  THEN IMP_RES_THEN MP_TAC func_to_prod_lem
  THEN RW_TAC flag_ss []
  THEN RES_TAC
  THEN RW_TAC arith_ss [Weight_def]);

(* Discharge the termination hypothesis of flag_def1. What remains splits *)
(* into the recursion equation for flag and its induction theorem, which  *)
(* are both saved in the current theory. The list pattern expects exactly *)
(* two conjuncts; otherwise the binding raises Bind.                      *)
val [flag_eqn,flag_induction] = CONJUNCTS(PROVE_HYP flag_terminates flag_def1);
val _ = save_thm("flag_eqn", flag_eqn);
val _ = save_thm("flag_induction", flag_induction);
 

(*---------------------------------------------------------------------------*
 * Dnf permutes its input.                                                   *
 *---------------------------------------------------------------------------*)

(* Dnf_permutation: the list returned by Dnf is a permutation of its      *)
(* input. The proof is by recursion induction on Dnf:                     *)
(*   - cases solvable with permutation_refl are closed first;             *)
(*   - otherwise the proof tries to finish outright by unfolding          *)
(*     permutation_def and filter_def (NO_TAC makes that attempt          *)
(*     all-or-nothing);                                                   *)
(*   - failing that, it falls back to prod_fg_var and permutation_mono    *)
(*     for the recursive clause.                                          *)
val Dnf_permutation = Q.prove
`!L. permutation L (FST (Dnf L))`
(PROGRAM_TAC{rules = Dnf_eqns, induction = Dnf_induction}
  THEN RW_TAC flag_ss [permutation_refl]
  THEN ((REWRITE_TAC [permutation_def,filter_def] 
         THEN RW_TAC flag_ss [] THEN NO_TAC)
        ORELSE
         RW_TAC flag_ss [prod_fg_var,permutation_mono]));


(*---------------------------------------------------------------------------*
 * On the last call to Dnf, the input and output lists are the same.         *
 *---------------------------------------------------------------------------*)
(* Dnf_no_swaps: if Dnf L reports no swap (second component F), its      *)
(* output list is L itself. The proof is by recursion induction on Dnf.  *)
(* In the recursive clause, func_to_prod_lem splits the pair built by    *)
(* (CONS x ## I).                                                        *)
val Dnf_no_swaps = Q.prove
`!L alist. ((alist,F) = Dnf L) ==> (alist=L)`
(REC_INDUCT_TAC Dnf_induction
   THEN RW_TAC flag_ss [Dnf_eqns]
   THEN IMP_RES_THEN MP_TAC func_to_prod_lem
   THEN RW_TAC flag_ss []);


(*---------------------------------------------------------------------------*
 * When no swaps get made, the arrangement of the list is correct.           *
 *---------------------------------------------------------------------------*)

(* final_step_correct: if Dnf L makes no swap, i.e. returns (L,F), then L *)
(* is already arranged: all its R's, then all its W's, then all its B's.  *)
(* The proof is by recursion induction on Dnf.                            *)
val final_step_correct = Q.prove
`!L. ((L,F) = Dnf L)
      ==>
     (L = APPEND (filter ($= R) L) 
           (APPEND (filter ($= W) L) 
                    (filter ($= B) L)))`
(PROGRAM_TAC{rules=Dnf_eqns, induction=Dnf_induction}
  THEN POP_ASSUM MP_TAC 
  THEN RW_TAC flag_ss [filter_def,Dnf_eqns,prod_fg_var] 
  (* Four goals remain. In each one, lem rebuilds the no-swap equation   *)
  (* for the recursive call on the tail, so that the induction           *)
  (* hypothesis can fire (RES_THEN).                                     *)
  THENL [`(CONS B rst,F) = Dnf (CONS B rst)` by PROVE_TAC [lem],
         `(CONS B rst,F) = Dnf (CONS B rst)` by PROVE_TAC [lem],
         `(CONS W rst,F) = Dnf (CONS W rst)` by PROVE_TAC [lem],
         `(rst,F) = Dnf rst`                 by PROVE_TAC [lem]]
  THEN RES_THEN MP_TAC THEN REPEAT (POP_ASSUM (K ALL_TAC))
  THEN RW_TAC flag_ss [filter_def] THEN POP_ASSUM MP_TAC
  (* Case on whether the relevant filtered prefix is empty. A non-empty  *)
  (* one has a head whose colour is fixed by filter_lem or               *)
  (* append_filters_lem. Presumably that colour contradicts the head     *)
  (* already known, via the colour datatype rewrites.                    *)
  THENL [Cases_on `APPEND (filter ($= R) rst) (filter ($= W) rst)`,
         Cases_on `filter ($= R) rst`, Cases_on `filter ($= R) rst`]
  THEN RW_TAC flag_ss []
  THENL map IMP_RES_TAC [append_filters_lem,filter_lem,filter_lem,filter_lem]
  THEN PROVE_TAC (type_rws"colour"));




(*---------------------------------------------------------------------------*
 * Needed to implement a higher-order rewrite with a first order rewriter.   *
 * (The simplifier fails to deal with paired lets.)                          *
 *---------------------------------------------------------------------------*)
(* let_lem: instance of pairTools.PULL_LET2 for the property used in       *)
(* flag_correct. The property is                                          *)
(*   h = APPEND (filter ($= R) L) (APPEND (filter ($= W) L) ...)          *)
(* and let_lem moves it across the paired let in the body of flag, with   *)
(* BETA_RULE reducing the lambda applications. Presumably PULL_LET2 has   *)
(* the form                                                               *)
(*   P (let (x,y) = M in f x y) = let (x,y) = M in P (f x y)              *)
(* -- TODO confirm against pairTools.                                     *)
val let_lem = 
 BETA_RULE(Q.ISPECL 
 [`\h. h = APPEND (filter ($= R) L) 
             (APPEND (filter ($= W) L) (filter ($= B) L))`,
  `Dnf L`,
  `\alist changed. changed => (flag alist) | alist`]  pairTools.PULL_LET2);


(*---------------------------------------------------------------------------*
 * Correctness: All occurrences of R in "flag L" are before all              *
 * occurrences of W, which are before all occurrences of B. This is          *
 * expressible in terms of append:                                           *
 *                                                                           *
 *    !L. ?l1 l2 l3. (flag L = APPEND l1 (APPEND l2 l3)) /\                  *
 *                   (!x. mem x l1 ==> (x=R)) /\                             *
 *                   (!x. mem x l2 ==> (x=W)) /\                             *
 *                   (!x. mem x l3 ==> (x=B))                                *
 *                                                                           *
 * Witnesses for l1, l2, and l3 can be given explicitly by filtering L       *
 * for the particular colour.                                                *
 *                                                                           *
 *---------------------------------------------------------------------------*)

(* flag_correct: the output of flag is the R's, then the W's, then the B's *)
(* of its input. The witnesses are the three filters of L. The membership  *)
(* conjuncts are discharged with mem_filter. The remaining equation for    *)
(* flag L is proved by recursion induction on flag.                        *)
val flag_correct = Q.store_thm
("flag_correct",
`!L. ?l1 l2 l3. (flag L = APPEND l1 (APPEND l2 l3))  /\
                 (!x. mem x l1 ==> (x=R)) /\
                 (!x. mem x l2 ==> (x=W)) /\
                 (!x. mem x l3 ==> (x=B))`,
GEN_TAC 
 THEN Q.EXISTS_TAC`filter ($=R) L` 
 THEN Q.EXISTS_TAC`filter ($=W) L`
 THEN Q.EXISTS_TAC`filter ($=B) L` 
 THEN RW_TAC bool_ss [mem_filter]
 THEN Q.ID_SPEC_TAC `L`
 THEN PROGRAM_TAC{induction = flag_induction, rules=flag_eqn}
 (* Push the property inside the paired let of flag's body. The           *)
 (* components introduced by LET_INTRO_TAC (presumably x and y) are then  *)
 (* generalised as alist and changed, and RW_TAC splits on changed.       *)
 THEN PURE_REWRITE_TAC[let_lem] THEN pairTools.LET_INTRO_TAC
 THEN MAP_EVERY Q.SPEC_TAC [(`x:colour list`,`alist`), (`y:bool`,`changed`)]
 THEN RW_TAC bool_ss [] THENL
 (* A swap was made: use the induction hypothesis for alist, and          *)
 (* Dnf_permutation to turn the filters of alist into the filters of l0.  *)
 [`alist:colour list = FST(Dnf l0)` by PROVE_TAC pairTheory.pair_rws
   THEN RW_TAC bool_ss 
         [GSYM (REWRITE_RULE [permutation_def] Dnf_permutation)],
  (* No swap: Dnf returned l0 unchanged, and final_step_correct applies.  *)
  IMP_RES_TAC Dnf_no_swaps THEN ZAP_TAC bool_ss [final_step_correct]]);


(*---------------------------------------------------------------------------*
 * We also need to specify that flag permutes its input. Otherwise it        *
 * could just return an empty list!                                          *
 *---------------------------------------------------------------------------*)

(* let_lem1: instance of pairTools.PULL_LET2 for the property            *)
(* permutation l0 h, used in flag_permutes to move that property across  *)
(* the paired let in the body of flag. BETA_RULE reduces the lambda      *)
(* applications introduced by the instantiation.                         *)
val let_lem1 = 
 BETA_RULE(Q.ISPECL 
 [`\h:colour list. permutation l0 h`,
  `Dnf l0`,
  `\l1 changed. changed => (flag l1) | l1`]  pairTools.PULL_LET2);

(* perm_trans: transitivity of permutation with the transitive predicate *)
(* unfolded by transitive_def, so that PROVE_TAC can chain it directly.  *)
val perm_trans =
  let val trans_thm = permTheory.permutation_trans
  in
    REWRITE_RULE [TCTheory.transitive_def] trans_thm
  end;


(* flag_permutes: flag l is a permutation of l. The proof is by recursion *)
(* induction on flag, with the same let-pushing setup as flag_correct but *)
(* using let_lem1.                                                        *)
val flag_permutes = Q.store_thm("flag_permutes",
`!l. permutation l (flag l)`,
PROGRAM_TAC{induction = flag_induction, rules=flag_eqn}
 THEN PURE_REWRITE_TAC [let_lem1] THEN pairTools.LET_INTRO_TAC
 THEN MAP_EVERY Q.SPEC_TAC [(`x:colour list`,`alist`), (`y:bool`,`changed`)]
 THEN RW_TAC bool_ss [] THENL 
 (* A swap was made: the induction hypothesis gives                        *)
 (* permutation alist (flag alist), and Dnf_permutation relates l0 to      *)
 (* alist. Chain the two with transitivity.                                *)
 [`alist:colour list = FST(Dnf l0)` by PROVE_TAC pairTheory.pair_rws THEN
  `permutation alist (flag alist)`  by (RES_TAC THEN PROVE_TAC[])    THEN
  PROVE_TAC [perm_trans,Dnf_permutation],
  (* No swap: alist is l0 itself, so reflexivity suffices.               *)
  PROVE_TAC [Dnf_no_swaps, permutation_refl]]);

(* Print the inference count accumulated by the meter created at the top. *)
val _ = Count.report (Count.read meter);

