(*---------------------------------------------------------------------------

      Definition and correctness of a naive quicksort algorithm.

 ---------------------------------------------------------------------------*)

app load ["bossLib", "Q", "sortingTheory"];

open bossLib sortingTheory permTheory listTheory combinTheory relationTheory;
(* Fixity declaration: from here on the identifier by is parsed as a
   left-associative infix operator at precedence level 8.
   NOTE(review): by is presumably the bossLib tactical of that name; it is
   not used in any proof visible in this file -- confirm before removing. *)
infix 8 by;

(*---------------------------------------------------------------------------
     Misc. lemmas, should perhaps be part of listTheory.
 ---------------------------------------------------------------------------*)

(* FILTER_KT: filtering a list with the constantly-true predicate (K T)
   returns the list unchanged.
   Proof: structural induction on the list; both the nil case and the cons
   case are closed by the same simplification, which unfolds FILTER and K. *)
val FILTER_KT = Q.prove(`!l. FILTER (K T) l = l`,
let
  (* Unfold FILTER and K on top of the standard list simpset. *)
  val unfold_tac = RW_TAC list_ss [FILTER,K_DEF]
in
  Induct THENL
    [unfold_tac,   (* base case: empty list *)
     unfold_tac]   (* step case: h::l       *)
end);

(* LENGTH_FILTER_SUBSET: if predicate P implies predicate Q pointwise, then
   filtering any list L by P keeps at most as many elements as filtering it
   by Q.
   P and Q are free variables of the statement (only L is universally
   quantified inside the consequent), so the theorem can be specialised
   with Q.INST; the termination proof of qsort below does exactly that.
   Proof: move the hypothesis into the assumptions, induct on L, unfold
   FILTER with the list simpset, and let PROVE_TAC finish what remains. *)
val LENGTH_FILTER_SUBSET = Q.prove(
`(!y. P y ==> Q y) ==> !L. LENGTH (FILTER P L) <= LENGTH (FILTER Q L)`,
DISCH_TAC THEN Induct 
  THEN RW_TAC list_ss [FILTER]
  THEN PROVE_TAC[]);

(* MEM_FILTER: x is a member of FILTER P L exactly when x satisfies P and
   x is a member of L.
   Proof: induction on the list L, simplification with the definitions of
   MEM and FILTER, then first-order reasoning (PROVE_TAC) using MEM. *)
val MEM_FILTER = Q.prove
(`!P L x. MEM x (FILTER P L) = P x /\ MEM x L`,
Induct_on `L`  
 THEN RW_TAC list_ss [MEM,FILTER] 
 THEN PROVE_TAC [MEM]);

(* MEM_APPEND_DISJ: x is a member of APPEND l1 l2 exactly when it is a
   member of l1 or a member of l2.
   Proof: induction on the first list l1, simplification with the
   definitions of APPEND and MEM, then PROVE_TAC for the propositional
   residue. *)
val MEM_APPEND_DISJ = Q.prove
(`!x l1 l2. MEM x (APPEND l1 l2) = MEM x l1 \/ MEM x l2`,
Induct_on `l1` THEN RW_TAC list_ss [APPEND,MEM] THEN PROVE_TAC[]);


(*---------------------------------------------------------------------------*
 * The property of a relation being total.                                   *
 *---------------------------------------------------------------------------*)

(* total R holds when every pair of elements is related by R in at least
   one direction.  Instantiating y with x shows that a total relation is
   in particular reflexive.  The defining equation is bound to total_def. *)
val total_def = Define `total R = !x y. R x y \/ R y x`;


(*---------------------------------------------------------------------------*
 * The quicksort algorithm.                                                  *
 *---------------------------------------------------------------------------*)

(* qsort r l: quicksort of the list l with respect to the relation r.
     - The empty list sorts to the empty list.
     - For h::t the head h is the pivot.  The elements x of t with r x h
       are sorted recursively and placed before h; the elements x of t
       with ~(r x h) are sorted recursively and placed after h.
   The tail t is traversed twice, once per FILTER.
   The definition is made with Hol_defn and the resulting value is passed
   to Defn.tprove in the termination proof that follows, which is where
   the recursion equations and the induction theorem are obtained. *)
val qsort_def = 
 Hol_defn "qsort"
   `(qsort r [] = []) /\
    (qsort r (h::t) = 
        APPEND (qsort r (FILTER (\x. r x h) t))
          (h :: qsort r (FILTER (\x. ~(r x h)) t)))`;


(*---------------------------------------------------------------------------
       Termination proof.
 ---------------------------------------------------------------------------*)

(* Termination of qsort.  Defn.tprove takes the definition together with a
   tactic for its termination conditions and returns a pair, bound here as
     qsort_eqns : the recursion equations of qsort,
     qsort_ind  : the induction theorem used by the proofs below.
   The termination relation is measure (LENGTH o SND), i.e. the length of
   the second component of the argument pair (the list being sorted).
   NOTE(review): the remaining obligations presumably have the shape
   LENGTH (FILTER p t) < SUC (LENGTH t) -- confirm against the goals
   produced by WF_REL_TAC. *)
val (qsort_eqns,qsort_ind) = 
Defn.tprove
  (qsort_def,
   WF_REL_TAC qsort_def `measure (LENGTH o SND)`
     THEN RW_TAC list_ss 
         (* First rewrite turns a strict bound by SUC y into <= y. *)
         [DECIDE `x<SUC y = x <= y`, o_DEF,K_DEF,
          (* LENGTH_FILTER_SUBSET with Q instantiated to K T; rewriting
             with FILTER_KT collapses FILTER (K T) L to L, leaving
             LENGTH (FILTER P L) <= LENGTH L in the conclusion. *)
          REWRITE_RULE [FILTER_KT]
              (Q.INST [`Q` |-> `K T`] LENGTH_FILTER_SUBSET)]);


(*---------------------------------------------------------------------------*
 *           Properties of qsort                                             *
 *---------------------------------------------------------------------------*)

(* qsort_MEM_stable: sorting neither adds nor removes elements as far as
   membership is concerned; x is a member of qsort R L exactly when it is
   a member of L.  (The statement is about membership only; it says
   nothing about multiplicity or about the order of equal elements.)
   Proof: induction along the recursion of qsort (qsort_ind), unfolding
   qsort and rewriting membership in APPEND and FILTER with the lemmas
   proved above; PROVE_TAC closes the propositional residue. *)
val qsort_MEM_stable = Q.prove
(`!R L x. MEM x (qsort R L) = MEM x L`,
  recInduct qsort_ind 
    THEN RW_TAC list_ss [qsort_eqns,MEM_APPEND_DISJ,MEM,MEM_FILTER] 
    THEN PROVE_TAC[]);


(* qsort_PERM: the output of qsort is a permutation of its input, for any
   relation R.
   NOTE: the theorem is stored under the name "qsort_PERM" while the ML
   binding is qsort_perm; the two spellings differ only in case.
   Proof outline: induct along the recursion of qsort and unfold it; the
   goal is then reduced with CONS_PERM and PERM_trans1, supplying as
   witness the concatenation of the two FILTER partitions of the tail;
   the rest is discharged by simplification with PERM_split (after
   unfolding o and beta-reducing) and PERM_cong.
   NOTE(review): CONS_PERM, PERM_trans1, PERM_split and PERM_cong are not
   defined in this file (presumably they come from the opened theories);
   check their statements there before changing this proof. *)
val qsort_perm = Q.store_thm
("qsort_PERM",
 `!R L. PERM L (qsort R L)`,
 recInduct qsort_ind 
  THEN RW_TAC list_ss [qsort_eqns,PERM_refl,APPEND]
  THEN MATCH_MP_TAC CONS_PERM 
  THEN MATCH_MP_TAC PERM_trans1
  THEN Q.EXISTS_TAC`APPEND (FILTER(\x. r x h) t) (FILTER(\x. ~r x h) t)`
  THEN RW_TAC std_ss [BETA_RULE(REWRITE_RULE[o_DEF] PERM_split),PERM_cong]);


(* qsort_sorts: if R is transitive and total, then qsort R L is SORTED
   with respect to R, for every list L.
   Proof outline: induct along the recursion of qsort and unfold qsort and
   SORTED; apply SORTED_APPEND to split the goal for the appended result;
   rewrite with whatever instances of SORTED_eq can be resolved against
   the assumptions; then reduce the membership side conditions with
   MEM_FILTER and qsort_MEM_stable and finish by first-order reasoning
   from the definitions of transitive and total.
   NOTE(review): SORTED_def, SORTED_APPEND and SORTED_eq are not defined
   in this file (presumably sortingTheory); check their statements there
   before changing this proof. *)
val qsort_sorts = 
Q.store_thm
("qsort_sorts",
`!R L. transitive R /\ total R ==> SORTED R (qsort R L)`,
 recInduct qsort_ind 
  THEN RW_TAC list_ss [qsort_eqns,SORTED_def]
  THEN MATCH_MP_TAC SORTED_APPEND
  THEN IMP_RES_THEN (fn th => ASM_REWRITE_TAC [th]) SORTED_eq
  THEN RW_TAC list_ss [MEM_FILTER,MEM,qsort_MEM_stable] 
  THEN PROVE_TAC [transitive_def,total_def]);


(*---------------------------------------------------------------------------*
 * Bring everything together.                                                *
 *---------------------------------------------------------------------------*)

(* qsort_correct: for every transitive and total relation R, qsort
   satisfies the performs_sorting specification with respect to R.
   Proved by first-order reasoning from the definition of
   performs_sorting together with the two main results above:
   qsort_perm (the output is a permutation of the input) and qsort_sorts
   (the output is sorted).
   NOTE(review): performs_sorting_def is not defined in this file
   (presumably sortingTheory); confirm its exact statement there. *)
val qsort_correct = Q.store_thm
("qsort_correct", 
`!R. transitive R /\ total R ==> performs_sorting qsort R`,
PROVE_TAC
  [performs_sorting_def, qsort_perm, qsort_sorts]);
