(*---------------------------------------------------------------------------*
 *                                                                           *
 *       The Dutch National Flag                                             *
 *                                                                           *
 * The example is originally from Dijkstra, but is/was very popular as a     *
 * simple term rewriting system example. There is a linear arrangement of    *
 * objects coloured Red, White, or Blue. The term rewriting system           *
 *                                                                           *
 *          WR --> RW                                                        *
 *          BR --> RB                                                        *
 *          BW --> WB                                                        *
 *                                                                           *
 * always terminates when applied to the objects, and leaves them such that  *
 * all the Red ones come first, then the White ones, then the Blue ones.     *
 * When represented as a functional program, we essentially get a            *
 * bubblesort.                                                               *
 *---------------------------------------------------------------------------*)


(*---------------------------------------------------------------------------*
 * Load and open useful libraries and theories.                              *
 *---------------------------------------------------------------------------*)

app load ["bossLib", "permTheory"]; open bossLib listTheory permTheory;

(*---------------------------------------------------------------------------*
 * Define the colours.                                                       *
 *---------------------------------------------------------------------------*)

(* The HOL datatype of flag colours, with one nullary constructor per colour: *)
(* R (red), W (white) and B (blue).                                           *)
Hol_datatype `colour = R | W | B`;


(*---------------------------------------------------------------------------*
 * Define the swapping function.                                             *
 *---------------------------------------------------------------------------*)

(* Swap l performs at most one step of the rewriting system on l.            *)
(*                                                                           *)
(*   - If l starts with one of the out-of-order pairs W R, B R or B W, the   *)
(*     two elements are exchanged and the new list is returned under SOME.   *)
(*   - Otherwise the head is kept and the search continues in the tail; the  *)
(*     head is consed back onto the result through OPTION_MAP.               *)
(*   - NONE is returned for the empty list, and hence whenever the search    *)
(*     reaches the end of the list without finding such a pair.              *)
(*                                                                           *)
(* The last clause overlaps the three before it, so the order in which the   *)
(* clauses are written matters.                                              *)
val Swap_def =
 Define 
   `(Swap  []         = NONE) 
 /\ (Swap (W::R::rst) = SOME(R::W::rst))
 /\ (Swap (B::R::rst) = SOME(R::B::rst))
 /\ (Swap (B::W::rst) = SOME(W::B::rst))
 /\ (Swap (x::rst)    = OPTION_MAP (CONS x) (Swap rst))`;


(* Swap_eqns: the clauses of Swap_def in simplified form.                    *)
(*                                                                           *)
(* The first conjunct of Swap_def (the clause for the empty list) is kept    *)
(* as it is.  The remaining conjuncts are rewritten with that clause and     *)
(* with the definition of OPTION_MAP, and the two parts are conjoined again. *)
val Swap_eqns =
  let
    val nil_clause   = CONJUNCT1 Swap_def
    val cons_clauses = CONJUNCT2 Swap_def
    val simp_rules   = [nil_clause, optionTheory.OPTION_MAP_DEF]
  in
    CONJ nil_clause (REWRITE_RULE simp_rules cons_clauses)
  end;

(* Swap_ind_thm: an induction principle for lists of colours whose cases     *)
(* follow the shape of Swap.                                                 *)
(*                                                                           *)
(* To show P v for every v it suffices to show P for [], [B] and [W], for    *)
(* every list starting with an out-of-order pair (W R, B R, B W), and to     *)
(* show that P is preserved by the four listed inductive steps.              *)
(*                                                                           *)
(* It is proved from the induction theorem "Swap_ind" stored in the current  *)
(* theory segment ("-").                                                     *)
(*                                                                           *)
(* The bound variable names v7 and v3 are referred to by name in the proof   *)
(* of Swap_NONE; do not rename them without updating that proof.             *)
val Swap_ind_thm = Q.prove
(`!P. P [] /\ P [B] /\ P [W] /\
      (!rst. P (W::R::rst))  /\ 
      (!rst. P (B::R::rst))  /\
      (!rst. P (B::W::rst))  /\ 
      (!v7.  P (B::v7) ==> P (B::B::v7)) /\
      (!v3.  P (B::v3) ==> P (W::B::v3)) /\
      (!v3.  P (W::v3) ==> P (W::W::v3)) /\ 
      (!rst. P rst ==> P (R::rst)) 
    ==>
        !v. P v`,
GEN_TAC THEN DISCH_TAC 
 THEN MATCH_MP_TAC (fetch "-" "Swap_ind")
 THEN RW_TAC std_ss []);


(*---------------------------------------------------------------------------*
 * Define the Flag function.                                                 *
 *---------------------------------------------------------------------------*)

(* Flag applies Swap repeatedly: when Swap finds nothing to exchange (NONE)  *)
(* the list is returned unchanged, otherwise Flag is applied again to the    *)
(* swapped list.                                                             *)
(*                                                                           *)
(* Hol_defn only records the definition; the termination proof is supplied   *)
(* separately, through Defn.tprove, further down in this file.               *)
(*                                                                           *)
(* The parameter name `list` is referred to by name in the proofs of         *)
(* Flag_correct and Flag_PERM (Cases_on `Swap list`).                        *)
val Flag_defn = 
 Defn.Hol_defn "Flag" 
               `Flag list = case (Swap list)
                             of NONE -> list 
                             || SOME slist -> Flag slist`;


(*---------------------------------------------------------------------------*
 * Termination (measure function courtesy of jrh).                           *
 *---------------------------------------------------------------------------*)

(* Weight maps a list of colours to a natural number; it is the measure used *)
(* in the termination proof of Flag.  Writing w for Weight rst, each rewrite *)
(* step strictly decreases the weight:                                       *)
(*                                                                           *)
(*     W::R::rst  :  8 + 4w    -->   R::W::rst  :  7 + 4w                    *)
(*     B::R::rst  :  7 + 4w    -->   R::B::rst  :  5 + 4w                    *)
(*     B::W::rst  :  5 + 4w    -->   W::B::rst  :  4 + 4w                    *)
val Weight_def =
 Define
     `(Weight (R::rst) = 3 + 2*Weight rst) 
  /\  (Weight (W::rst) = 2 + 2*Weight rst)
  /\  (Weight (B::rst) = 1 + 2*Weight rst)
  /\  (Weight []       = 0)`;


(*---------------------------------------------------------------------------*)
(* Termination proof for Flag                                                *)
(*---------------------------------------------------------------------------*)

(* Discharges the termination obligation of Flag_defn and binds the          *)
(* resulting pair to Flag_eqns and Flag_ind.                                 *)
(*                                                                           *)
(* The termination relation is `measure Weight`.  The proof goes by          *)
(* induction with Swap_ind_thm, simplifies with the Swap and Weight          *)
(* equations, feeds the induction hypotheses back into the goal              *)
(* (RES_THEN MP_TAC) and closes the remaining arithmetic with DECIDE_TAC.    *)
val (Flag_eqns,Flag_ind) = 
Defn.tprove
(Flag_defn,
 WF_REL_TAC `measure Weight`
  THEN recInduct Swap_ind_thm
  THEN RW_TAC std_ss [Swap_eqns,Weight_def] 
  THEN RW_TAC arith_ss [Weight_def]
  THEN RES_THEN MP_TAC 
  THEN DECIDE_TAC);

(*---------------------------------------------------------------------------*
 * Miscellaneous lemmas                                                      *
 *---------------------------------------------------------------------------*)

(* MEM_FILTER: x is a member of FILTER P L exactly when P holds of x and x   *)
(* is a member of L.  Proved by induction on L.                              *)
(*                                                                           *)
(* NOTE(review): listTheory is opened at the top of this file; if it already *)
(* exports a theorem called MEM_FILTER, this binding shadows it -- confirm.  *)
val MEM_FILTER = Q.prove
(`!P L x. MEM x (FILTER P L) = P x /\ MEM x L`,
Induct_on `L`  
 THEN RW_TAC list_ss [MEM,FILTER] 
 THEN PROVE_TAC [MEM]);

(* FILTER_lem: if FILTER P l is non-empty, its head satisfies P.             *)
(* Proved by induction on l.                                                 *)
val FILTER_lem = Q.prove
(`!P l h t. (FILTER P l = h::t) ==> P h`,
Induct_on `l`
  THEN RW_TAC list_ss [FILTER]
  THEN PROVE_TAC[]);

(* APPEND_FILTER_lem: if the concatenation of FILTER P l1 and FILTER Q l2 is *)
(* non-empty, its head satisfies P or Q.                                     *)
(*                                                                           *)
(* The proof splits on whether FILTER P l1 is empty and then appeals to      *)
(* FILTER_lem.                                                               *)
val APPEND_FILTER_lem = Q.prove
(`!P Q l1 l2 h t. 
   (APPEND (FILTER P l1) (FILTER Q l2) = h::t) ==> P h \/ Q h`,
REPEAT GEN_TAC 
  THEN Cases_on `FILTER P l1` 
  THEN RW_TAC list_ss [FILTER]
  THEN PROVE_TAC [FILTER_lem]);

(*---------------------------------------------------------------------------*
 * Swap permutes its input when it returns a SOME'd item.                    *
 *---------------------------------------------------------------------------*)

(* Swap_SOME: whenever Swap l1 returns SOME l2, l2 is a permutation of l1.   *)
(*                                                                           *)
(* Proved by induction with Swap_ind_thm.  The goal is normalised with the   *)
(* Swap equations, PERM_def and FILTER; the induction hypotheses are then    *)
(* used as right-to-left rewrites (GSYM).                                    *)
(* NOTE(review): rewriting with PERM_def together with FILTER suggests that  *)
(* PERM is defined in terms of FILTER in permTheory -- confirm.              *)
val Swap_SOME = Q.prove
(`!l1 l2. (Swap l1 = SOME l2) ==> PERM l1 l2`,
recInduct Swap_ind_thm
  THEN BasicProvers.NORM_TAC list_ss [Swap_eqns,PERM_def,FILTER]
  THEN RES_THEN (fn th => REWRITE_TAC [GSYM th])
  THEN RW_TAC std_ss []);

(*---------------------------------------------------------------------------*
 * When no Swaps get made, the arrangement of the list is correct.           *
 *---------------------------------------------------------------------------*)

(* Swap_NONE: if Swap finds nothing to exchange in L, then L is exactly its  *)
(* R elements, followed by its W elements, followed by its B elements.       *)
(*                                                                           *)
(* Proved by induction with Swap_ind_thm.  After normalisation the induction *)
(* hypotheses are moved into the goal and all assumptions are discarded      *)
(* (REPEAT (WEAKEN_TAC (K true))).  The THENL list expects exactly three     *)
(* remaining subgoals, in this order; each is handled by a case split on a   *)
(* filtered prefix, using APPEND_FILTER_lem or FILTER_lem in the non-empty   *)
(* case.                                                                     *)
(*                                                                           *)
(* The variables v7 and v3 named in the case splits are the bound variables  *)
(* of Swap_ind_thm.                                                          *)
val Swap_NONE = Q.prove
(`!L. (Swap L = NONE)
       ==>
      (L = APPEND (FILTER ($= R) L) 
             (APPEND (FILTER ($= W) L) 
                       (FILTER ($= B) L)))`,
recInduct Swap_ind_thm 
  THEN BasicProvers.NORM_TAC list_ss [Swap_eqns,FILTER]
  THEN RES_THEN MP_TAC THEN REPEAT (WEAKEN_TAC (K true)) THENL 
  [Cases_on `APPEND (FILTER ($= R) v7) (FILTER ($= W) v7)` 
     THENL [ALL_TAC, IMP_RES_THEN MP_TAC APPEND_FILTER_lem],
   Cases_on `FILTER ($= R) v3` 
     THENL [ALL_TAC, IMP_RES_THEN MP_TAC FILTER_lem],
   Cases_on `FILTER ($= R) v3` 
     THENL [ALL_TAC, IMP_RES_THEN MP_TAC FILTER_lem]]
  THEN RW_TAC list_ss []);

(*---------------------------------------------------------------------------*
 * Correctness: All occurrences of R in "Flag L" are before all              *
 * occurrences of W, which are before all occurrences of B. This is          *
 * expressible in terms of append:                                           *
 *                                                                           *
 *    !L. ?l1 l2 l3. (Flag L = APPEND l1 (APPEND l2 l3)) /\                  *
 *                   (!x. MEM x l1 ==> (x=R)) /\                             *
 *                   (!x. MEM x l2 ==> (x=W)) /\                             *
 *                   (!x. MEM x l3 ==> (x=B))                                *
 *                                                                           *
 *---------------------------------------------------------------------------*)

(* Flag_correct (stored in the current theory under the name                 *)
(* "Flag_correct"): the result of Flag splits into three consecutive         *)
(* segments containing only R, only W and only B respectively.               *)
(*                                                                           *)
(* Proved by induction with Flag_ind.  Flag is unfolded once and the proof   *)
(* splits on the result of `Swap list`.  The three FILTERs of `list` are     *)
(* supplied as witnesses; Swap_NONE, used by the final PROVE_TAC, states     *)
(* that a list on which Swap returns NONE equals exactly that concatenation, *)
(* and MEM_FILTER characterises the members of each segment.                 *)
val Flag_correct = Q.store_thm
("Flag_correct",
`!L. ?l1 l2 l3. (Flag L = APPEND l1 (APPEND l2 l3))  /\
                 (!x. MEM x l1 ==> (x=R)) /\
                 (!x. MEM x l2 ==> (x=W)) /\
                 (!x. MEM x l3 ==> (x=B))`,
recInduct Flag_ind 
  THEN RW_TAC std_ss []
  THEN ONCE_REWRITE_TAC [Flag_eqns]
  THEN Cases_on `Swap list` 
  THEN RW_TAC std_ss [] 
  THEN MAP_EVERY Q.EXISTS_TAC 
         [`FILTER ($=R) list`, `FILTER ($=W) list`, `FILTER ($=B) list`]
  THEN RW_TAC std_ss [MEM_FILTER] THEN PROVE_TAC [Swap_NONE]);


(*---------------------------------------------------------------------------
         Flag is honest: its result is a permutation of its input.
 ---------------------------------------------------------------------------*)

(* Flag_PERM (stored in the current theory under the name "Flag_PERM"):      *)
(* Flag l is a permutation of l.                                             *)
(*                                                                           *)
(* Proved by induction with Flag_ind.  Flag is unfolded once and the proof   *)
(* splits on the result of `Swap list`: PERM_refl is supplied to the         *)
(* simplifier for the case where the list is returned unchanged, and         *)
(* Swap_SOME together with PERM_trans1 is given to PROVE_TAC for the case    *)
(* where a swap was made.                                                    *)
val Flag_PERM = Q.store_thm("Flag_PERM",
`!l. PERM l (Flag l)`,
recInduct Flag_ind THEN RW_TAC std_ss [] 
 THEN ONCE_REWRITE_TAC [Flag_eqns]
 THEN Cases_on `Swap list` THEN RW_TAC std_ss [PERM_refl]
 THEN PROVE_TAC [Swap_SOME,PERM_trans1]);
