(*---------------------------------------------------------------------------*
 *                                                                           *
 *       The Dutch National Flag                                             *
 *                                                                           *
 * The example is originally from Dijkstra, but is/was very popular as a     *
 * simple term rewriting system example. There is a linear arrangement of    *
 * objects coloured Red, White, or Blue. The term rewriting system           *
 *                                                                           *
 *          WR --> RW                                                        *
 *          BR --> RB                                                        *
 *          BW --> WB                                                        *
 *                                                                           *
 * always terminates when applied to the objects, and leaves them such that  *
 * all the Red ones come first, then the White ones, then the Blue ones.     *
 * When represented as a functional program, we essentially get a            *
 * bubblesort.                                                               *
 *---------------------------------------------------------------------------*)


(*---------------------------------------------------------------------------*
 * Load and open useful libraries and theories.                              *
 *---------------------------------------------------------------------------*)

app load ["bossLib", "permTheory", "Q"]; 
open bossLib;

infix 8 by;

open listTheory permTheory;

(*---------------------------------------------------------------------------*
 * Define Red, White, and Blue to be the colours.                            *
 *---------------------------------------------------------------------------*)

(* Introduces the type colour with the three constructors R, W and B.
 * Later proofs in this file look the type up by name, via type_rws "colour".
 *)
Hol_datatype `colour = R | W | B`;


(*---------------------------------------------------------------------------*
 * Define the swapping function.                                             *
 *---------------------------------------------------------------------------*)

(* Swap performs at most one rewrite step on a list of colours and returns
 * an option:
 *   - the empty list gives NONE;
 *   - a list beginning W::R, B::R or B::W gives SOME of the same list with
 *     those two leading elements exchanged;
 *   - for any other non-empty list the head x is kept, Swap is applied to
 *     the tail, and the result is passed through option_APPLY (CONS x).
 * NOTE(review): the final clause overlaps the three swapping clauses; this
 * presumably relies on Define giving priority to the earlier clauses --
 * confirm against the Define documentation.
 *)
val Swap_def =
 Define 
   `(Swap  []         = NONE) 
 /\ (Swap (W::R::rst) = SOME(R::W::rst))
 /\ (Swap (B::R::rst) = SOME(R::B::rst))
 /\ (Swap (B::W::rst) = SOME(W::B::rst))
 /\ (Swap (x::rst)    = option_APPLY (CONS x) (Swap rst))`;

(* Swap_eqns: a simplified form of Swap_def, used as the rewrite set for Swap
 * in the proofs below. The first conjunct of Swap_def (the clause for the
 * empty list) is kept unchanged; the remaining conjuncts are rewritten with
 * that clause and with the definition of option_APPLY, and the two parts
 * are conjoined again.
 *)
val Swap_eqns =
  let
    val nil_eqn    = CONJUNCT1 Swap_def
    val cons_eqns  = CONJUNCT2 Swap_def
    val simplified =
      REWRITE_RULE [nil_eqn, optionTheory.option_APPLY_DEF] cons_eqns
  in
    CONJ nil_eqn simplified
  end;

(* Swap_ind_thm: an induction principle over lists of colours, split by the
 * shape of the front of the list.
 *   Base cases: [], [B], [W], and the three shapes W::R::rst, B::R::rst and
 *               B::W::rst.
 *   Step cases: B::B::v7 from B::v7, W::B::v3 from B::v3, W::W::v3 from
 *               W::v3, and R::rst from rst.
 * It is obtained from the theorem stored under the name Swap_ind in the
 * current theory (presumably produced by the Define call for Swap --
 * confirm): strip the quantifier and hypothesis, match the goal against
 * Swap_ind, and simplify.
 * The bound-variable names v3, v7 and rst are referred to by name in later
 * proofs (see the Cases_on calls in Swap_NONE), so do not rename them.
 *)
val Swap_ind_thm = Q.prove
(`!P. P [] /\ P [B] /\ P [W] /\
      (!rst. P (W::R::rst))   /\ 
      (!rst. P (B::R::rst))   /\
      (!rst. P (B::W::rst)) /\ (!v7. P (B::v7) ==> P (B::B::v7)) /\
      (!v3. P (B::v3) ==> P (W::B::v3)) /\
      (!v3. P (W::v3) ==> P (W::W::v3)) /\ 
      (!rst. P rst ==> P (R::rst)) 
    ==>
        !v. P v`,
GEN_TAC THEN DISCH_TAC 
 THEN MATCH_MP_TAC (theorem "Swap_ind")
 THEN RW_TAC std_ss []);


(*---------------------------------------------------------------------------*
 * Define the Flag function. Note that eta-expansion "\l. Flag l" is needed. *
 *                                                                           *
 * ML equivalent: fun Flag l = case Swap l                                   *
 *                              of NONE    => l                              *
 *                               | SOME l' => Flag l'                        *
 *---------------------------------------------------------------------------*)

(* Flag_defn: the recursion equation for Flag, packaged with Defn.Hol_defn.
 * Flag applies Swap to its argument; when Swap returns NONE the argument is
 * returned as is, otherwise Flag is applied again to the swapped list.
 * The recursive call is written eta-expanded, as "\l. Flag l".
 * Termination is established separately: Flag_defn is passed to
 * Defn.tprove further down the file.
 *)
val Flag_defn = 
 Defn.Hol_defn "Flag" 
               `Flag list = option_case list (\l. Flag l) (Swap list)`;


(*---------------------------------------------------------------------------*
 * Termination (measure function courtesy of jrh).                           *
 *---------------------------------------------------------------------------*)

(* Weight: the numeric measure supplied to WF_REL_TAC in the termination
 * proof of Flag. The empty list weighs 0; a non-empty list weighs a
 * constant chosen by the colour of its head (3 for R, 2 for W, 1 for B)
 * plus twice the Weight of its tail.
 *)
val Weight_def =
 Define
     `(Weight (R::rst) = 3 + 2*Weight rst) 
  /\  (Weight (W::rst) = 2 + 2*Weight rst)
  /\  (Weight (B::rst) = 1 + 2*Weight rst)
  /\  (Weight []       = 0)`;


(* option_APPLY_lem1: if option_APPLY f opt equals SOME x, then opt itself
 * is SOME y for some y. Proved by case analysis on opt followed by
 * simplification.
 *)
val option_APPLY_lem1 = Q.prove
(`!f opt x. (option_APPLY f opt = SOME x) ==> ?y. opt = SOME y`,
Cases_on `opt` THEN RW_TAC std_ss []);

(* option_APPLY_lem2: if option_APPLY f opt equals NONE, then opt is NONE.
 * Proved by case analysis on opt followed by simplification.
 * Note: x is universally quantified in the statement but does not occur in
 * its body.
 *)
val option_APPLY_lem2 = Q.prove
(`!f opt x. (option_APPLY f opt = NONE) ==> (opt = NONE)`,
Cases_on `opt` THEN RW_TAC std_ss []);


(* option_APPLY_eq_SOME_TAC: helper tactic for goals carrying an assumption
 * of the form option_APPLY f opt = SOME x. It runs three stages in order:
 *   1. resolve the assumptions with option_APPLY_lem1, choose the witness
 *      of the resulting existential, and substitute that equation
 *      throughout the goal;
 *   2. move an assumption matching the pattern M = N back into the
 *      conclusion as an antecedent;
 *   3. simplify with std_ss.
 *)
val option_APPLY_eq_SOME_TAC =
  let
    val expose_SOME =
      IMP_RES_THEN (CHOOSE_THEN SUBST_ALL_TAC) option_APPLY_lem1
    val requeue_eqn = Q.PAT_ASSUM `M = N` MP_TAC
  in
    expose_SOME THEN requeue_eqn THEN RW_TAC std_ss []
  end;


(* Termination proof for Flag. Defn.tprove takes Flag_defn and a tactic, and
 * the resulting pair is bound to Flag_eqns (the recursion equations) and
 * Flag_ind (the induction theorem).
 * The tactic:
 *   - picks measure Weight as the termination relation;
 *   - inducts with Swap_ind_thm and unfolds Swap and Weight;
 *   - dispatches the resulting goals with a seven-entry THENL list: the
 *     first three are left untouched (ALL_TAC), the last four are first
 *     prepared with option_APPLY_eq_SOME_TAC;
 *   - finishes every goal with arithmetic simplification.
 * The THENL list must have exactly one entry per goal at that point, so any
 * change to Swap_ind_thm or to the earlier simplification may require
 * adjusting it.
 *)
val (Flag_eqns,Flag_ind) = 
Defn.tprove
(Flag_defn,
 WF_REL_TAC Flag_defn `measure Weight`
  THEN recInduct Swap_ind_thm
  THEN RW_TAC std_ss [Swap_eqns,Weight_def] 
  THENL [ALL_TAC,ALL_TAC,ALL_TAC,
         option_APPLY_eq_SOME_TAC,option_APPLY_eq_SOME_TAC,
         option_APPLY_eq_SOME_TAC,option_APPLY_eq_SOME_TAC]
   THEN RW_TAC arith_ss [Weight_def]);

(* Store the two results of the termination proof in the current theory
 * under the names Flag_eqns and Flag_ind.
 *)
val _ = save_thm("Flag_eqns", Flag_eqns);
val _ = save_thm("Flag_ind", Flag_ind);

(*---------------------------------------------------------------------------*
 * Some miscellaneous stuff used in later proofs.                            *
 *---------------------------------------------------------------------------*)

(* MEM_FILTER: x is a member of FILTER P L exactly when P x holds and x is a
 * member of L. Proved by induction on L, simplification with the
 * definitions of MEM and FILTER, and first-order reasoning for what
 * remains.
 *)
val MEM_FILTER = Q.prove
(`!P L x. MEM x (FILTER P L) = P x /\ MEM x L`,
Induct_on `L`  
 THEN RW_TAC list_ss [MEM,FILTER] 
 THEN PROVE_TAC [MEM]);

(* FILTER_lem: if FILTER P l is a non-empty list h::t, then its head h
 * satisfies P. Proved by induction on l.
 *)
val FILTER_lem = Q.prove
(`!P l h t. (FILTER P l = h::t) ==> P h`,
Induct_on `l`
  THEN RW_TAC list_ss [FILTER]
  THEN PROVE_TAC[]);

(* APPEND_FILTER_lem: if the concatenation of FILTER P l1 and FILTER Q l2 is
 * a non-empty list h::t, then its head h satisfies P or Q.
 * Proved by case analysis on FILTER P l1 (the head comes either from the
 * first filtered list or, when that is empty, from the second), with
 * FILTER_lem supplying the fact about heads of filtered lists.
 *)
val APPEND_FILTER_lem = Q.prove
(`!P Q l1 l2 h t. 
   (APPEND (FILTER P l1) (FILTER Q l2) = h::t) ==> P h \/ Q h`,
REPEAT GEN_TAC 
  THEN Cases_on `FILTER P l1` 
  THEN RW_TAC list_ss [FILTER]
  THEN PROVE_TAC [FILTER_lem]);


(*---------------------------------------------------------------------------*
 * Swap permutes its input when it returns a SOME'd item.                    *
 *---------------------------------------------------------------------------*)

(* Swap_SOME: whenever Swap l1 returns SOME l2, l2 is a permutation of l1.
 * Proof: induct with Swap_ind_thm and unfold Swap with Swap_eqns. Each
 * remaining goal is then attacked in one of two ways:
 *   - first try to close it outright by expanding PERM_def and FILTER and
 *     simplifying; the trailing NO_TAC makes this attempt fail unless it
 *     solves the goal completely;
 *   - otherwise expose the SOME produced by the recursive Swap call with
 *     option_APPLY_eq_SOME_TAC and finish with PERM_mono.
 *)
val Swap_SOME = Q.prove
(`!l1 l2. (Swap l1 = SOME l2) ==> PERM l1 l2`,
 recInduct Swap_ind_thm
  THEN RW_TAC std_ss [Swap_eqns]
  THEN ((RW_TAC list_ss [PERM_def,FILTER] THEN NO_TAC)
       ORELSE
       (option_APPLY_eq_SOME_TAC THEN PROVE_TAC [PERM_mono])));


(*---------------------------------------------------------------------------*
 * When no Swaps get made, the arrangement of the list is correct.           *
 *---------------------------------------------------------------------------*)

(* Swap_NONE: if Swap L returns NONE, then L is equal to its R elements
 * followed by its W elements followed by its B elements, i.e. L is already
 * arranged as the flag.
 * Proof outline: induct with Swap_ind_thm and unfold Swap with Swap_eqns.
 * Goals that simplification with FILTER and APPEND solves outright are
 * discharged first (the NO_TAC makes that attempt all-or-nothing). For the
 * rest, option_APPLY_lem2 turns the NONE result of the whole call into a
 * NONE result for the recursive call, and the induction hypothesis is
 * brought into the goal.
 * The THENL list has four entries and so expects exactly four goals at that
 * point. The variable names suggest they are the B::B::v7, the two v3 step
 * cases and the R::rst step case of Swap_ind_thm -- NOTE(review): confirm
 * the order by running the proof interactively. Each entry unfolds FILTER
 * and the colour rewrites, case-splits on the leading filtered list, and
 * uses FILTER_lem or APPEND_FILTER_lem to learn the colour of its head.
 *)
val Swap_NONE = Q.prove
(`!L. (Swap L = NONE)
       ==>
      (L = APPEND (FILTER ($= R) L) 
             (APPEND (FILTER ($= W) L) 
                       (FILTER ($= B) L)))`,
recInduct Swap_ind_thm THEN RW_TAC std_ss [Swap_eqns]
 THEN ((RW_TAC list_ss [FILTER,APPEND] THEN NO_TAC)
       ORELSE IMP_RES_TAC option_APPLY_lem2 THEN RES_THEN MP_TAC)
 THEN IMP_RES_TAC option_APPLY_lem2 THEN RES_THEN MP_TAC THENL 
  [(* goal mentioning v7: split on the R-and-W prefix of the tail *)
   REWRITE_TAC (APPEND_ASSOC::FILTER::type_rws"colour")
     THEN Cases_on `APPEND (FILTER ($= R) v7) (FILTER ($= W) v7)`
     THEN RW_TAC std_ss [APPEND]
     THEN IMP_RES_THEN MP_TAC APPEND_FILTER_lem,
   (* first goal mentioning v3: split on the R elements of the tail *)
   REWRITE_TAC (FILTER::type_rws"colour")
     THEN Cases_on `FILTER ($= R) v3` THEN RW_TAC std_ss [APPEND]
     THEN IMP_RES_THEN MP_TAC FILTER_lem,
   (* second goal mentioning v3: same steps as the previous entry *)
   REWRITE_TAC (FILTER::type_rws"colour")
     THEN Cases_on `FILTER ($= R) v3` THEN RW_TAC std_ss [APPEND]
     THEN IMP_RES_THEN MP_TAC FILTER_lem,
   (* goal mentioning rst: same steps, on the R elements of rst *)
   REWRITE_TAC (FILTER::type_rws"colour")
     THEN Cases_on `FILTER ($= R) rst` THEN RW_TAC std_ss [APPEND]
     THEN IMP_RES_THEN MP_TAC FILTER_lem]
 THEN RW_TAC std_ss []);



(*---------------------------------------------------------------------------*
 * Correctness: All occurrences of R in "Flag L" are before all              *
 * occurrences of W, which are before all occurrences of B. This is          *
 * expressible in terms of append:                                           *
 *                                                                           *
 *    !L. ?l1 l2 l3. (Flag L = APPEND l1 (APPEND l2 l3)) /\                  *
 *                   (!x. MEM x l1 ==> (x=R)) /\                             *
 *                   (!x. MEM x l2 ==> (x=W)) /\                             *
 *                   (!x. MEM x l3 ==> (x=B))                                *
 *                                                                           *
 *---------------------------------------------------------------------------*)

(* Flag_correct (stored in the theory under that name): for every list L,
 * Flag L splits into three consecutive segments l1, l2, l3 whose members
 * are all R, all W and all B respectively.
 * Proof: induct with Flag_ind, unfold Flag once with Flag_eqns, and
 * case-split on the result of Swap list. The witnesses offered for l1, l2
 * and l3 are the R, W and B elements of list, selected with FILTER;
 * MEM_FILTER gives the membership facts and Swap_NONE gives the equation
 * between list and the concatenation of the three filters.
 * NOTE(review): the SOME case is presumably closed from the induction
 * hypothesis by the simplification before the witnesses are supplied --
 * confirm interactively.
 *)
val Flag_correct = Q.store_thm
("Flag_correct",
`!L. ?l1 l2 l3. (Flag L = APPEND l1 (APPEND l2 l3))  /\
                 (!x. MEM x l1 ==> (x=R)) /\
                 (!x. MEM x l2 ==> (x=W)) /\
                 (!x. MEM x l3 ==> (x=B))`,
recInduct Flag_ind 
  THEN RW_TAC std_ss []
  THEN ONCE_REWRITE_TAC [Flag_eqns]
  THEN Cases_on `Swap list` 
  THEN RW_TAC std_ss [] 
  THEN MAP_EVERY Q.EXISTS_TAC 
         [`FILTER ($=R) list`, `FILTER ($=W) list`, `FILTER ($=B) list`]
  THEN RW_TAC std_ss [MEM_FILTER] THEN PROVE_TAC [Swap_NONE]);


(*---------------------------------------------------------------------------
         Flag is honest!
 ---------------------------------------------------------------------------*)

(* Flag_PERM (stored in the theory under that name): Flag l is a permutation
 * of l, so Flag neither drops nor invents elements.
 * Proof: induct with Flag_ind, unfold Flag once with Flag_eqns, and
 * case-split on the result of Swap list. When Swap returns NONE, Flag
 * returns its argument and PERM_refl applies. When Swap returns SOME,
 * Swap_SOME shows that the swapped list is a permutation of the original,
 * and PERM_trans1 chains this with the induction hypothesis.
 *)
val Flag_PERM = Q.store_thm("Flag_PERM",
`!l. PERM l (Flag l)`,
recInduct Flag_ind THEN RW_TAC std_ss [] 
 THEN ONCE_REWRITE_TAC [Flag_eqns]
 THEN Cases_on `Swap list` THEN RW_TAC std_ss [PERM_refl]
 THEN PROVE_TAC [Swap_SOME,PERM_trans1]);

