(*---------------------------------------------------------------------------
           Transformations for linear recursion.
 ---------------------------------------------------------------------------*)

(* Load the bossLib and Q structures and open bossLib, which supplies
   Define, RW_TAC, recInduct and the simpsets used below. *)
app load ["bossLib", "Q"]; open bossLib; 

(* Fixity declarations for the by and && operators.  Neither is used in the
   part of the script visible here -- presumably retained for interactive
   work with bossLib. *)
infix 8 by; infix &&;
(* Set the show_assums flag -- presumably so that hypotheses are displayed
   when theorems are printed; the schematic definitions below carry
   hypotheses. *)
show_assums := true;

(*---------------------------------------------------------------------------
       Definitions of schemas for linear recursion, 
       continuation-passing style, and accumulator style.
 ---------------------------------------------------------------------------*)

(* Linear recursion schema.  The variables atomic, A, join, dest and D are
   free in the body; in the theorems stated later in this file they show up
   as extra leading arguments (linRec A D atomic dest join x).
     atomic x : test selecting the non-recursive case
     A x      : result in the non-recursive case
     dest x   : argument of the recursive call
     D x      : value combined, via join, with the recursive result *)
val linRec_eqn = 
 Define
    `linRec (x:'a) = if atomic x then A x 
                     else join (linRec (dest x)) (D x:'b)`;

(* Continuation-passing form of the schema, taking a pair (x, f).
   In the non-recursive case the continuation f is applied to A x.
   Otherwise the recursion continues on dest x with a new continuation that
   joins its argument with D x before handing the result to f.
   This is the raw definition; cpRec_eqn, derived from it further down, is
   the version used in the proofs. *)
val cpRec_def0 = 
 Define
    `cpRec (x:'a, (f:'b -> 'c)) = 
        if atomic x then f (A x) 
        else cpRec (dest x, \u:'b. f (join u (D x:'b)))`;

(* Accumulator form of the schema, taking a pair (x, u) where u is the
   accumulator.  In the non-recursive case the result is join (A x) u.
   Otherwise the recursion continues on dest x with D x joined onto the
   front of the accumulator.
   This is the raw definition; accRec_eqn, derived from it further down, is
   the version used in the proofs. *)
val accRec_def0 = 
 Define
    `accRec (x:'a, (u:'b)) = 
        if atomic x then join (A x) u
        else accRec (dest x, join (D x:'b) u)`;


(*---------------------------------------------------------------------------
      Minor massaging of definitions in order to clean up the TCs.
 ---------------------------------------------------------------------------*)

(* Remove the outermost universal quantifier of a theorem by specialising
   it to its own bound variable. *)
fun ID_SPEC th =
  let val {Bvar = v, ...} = Rsyntax.dest_forall (concl th)
  in SPEC v th
  end;

(* Clean up the termination conditions of cpRec: the relation R on argument
   pairs is instantiated to inv_image R FST, so that the remaining
   conditions speak about a relation on the first component only.  The
   recursion equation and the induction theorem are processed together as
   one conjunction and split apart again at the end. *)
val [cpRec_eqn, cpRec_ind] =
  let
    (* WF_inv_image with R left as it is and f specialised to FST; its
       antecedent is moved into the hypotheses *)
    val wf_fst =
      UNDISCH (ISPEC (Term`FST:'a#('b->'c)->'a`)
                     (ID_SPEC relationTheory.WF_inv_image))
    (* definition and induction theorem, with hypotheses discharged *)
    val packed = DISCH_ALL (CONJ cpRec_def0 (theorem "cpRec_ind"))
    (* swap the relation on pairs for one pulled back along FST *)
    val pulled_back =
      Rsyntax.INST [Term`R :'a # ('b -> 'c) -> 'a # ('b -> 'c) -> bool`
                      |->
                    Term`inv_image R (FST :'a # ('b -> 'c) -> 'a)`]
                   packed
    val wf_done   = REWRITE_RULE [wf_fst] pulled_back
    val unfolded  = BETA_RULE
                      (REWRITE_RULE [relationTheory.inv_image_def] wf_done)
    val projected = REWRITE_RULE pairTheory.pair_rws unfolded
  in
    CONJUNCTS (UNDISCH projected)
  end;

(* Clean up the termination conditions of accRec in the same way as for
   cpRec: the relation R on argument pairs is instantiated to
   inv_image R FST, and the recursion equation and induction theorem are
   processed together and split apart at the end. *)
val [accRec_eqn, accRec_ind] =
  let
    (* WF_inv_image with R left as it is and f specialised to FST; its
       antecedent is moved into the hypotheses *)
    val wf_fst =
      UNDISCH (ISPEC (Term`FST:'a#'b->'a`)
                     (ID_SPEC relationTheory.WF_inv_image))
    (* definition and induction theorem, with hypotheses discharged *)
    val packed = DISCH_ALL (CONJ accRec_def0 (theorem "accRec_ind"))
    (* swap the relation on pairs for one pulled back along FST *)
    val pulled_back =
      Rsyntax.INST [Term`R :'a#'b -> 'a#'b -> bool`
                      |->
                    Term`inv_image R (FST :'a#'b -> 'a)`]
                   packed
    val wf_done   = REWRITE_RULE [wf_fst] pulled_back
    val unfolded  = BETA_RULE
                      (REWRITE_RULE [relationTheory.inv_image_def] wf_done)
    val projected = REWRITE_RULE pairTheory.pair_rws unfolded
  in
    CONJUNCTS (UNDISCH projected)
  end;


(* To overwrite old theorems

    save_thm("cpRec_def", cpRec_eqn);
    save_thm("cpRec_ind", cpRec_ind);
    save_thm("accRec_def", accRec_eqn);
    save_thm("accRec_ind", accRec_ind);
*)

(* Running cpRec with continuation f gives the same result as applying f
   to the result of linRec.  Proved by the induction principle of linRec,
   unfolding each recursion once. *)
val cpRec_eq_linRec =
  let
    (* one unfolding step of linRec followed by one of cpRec *)
    val unfold_once =
      ONCE_REWRITE_TAC [linRec_eqn] THEN ONCE_REWRITE_TAC [cpRec_eqn]
  in
    Q.prove
     (`!R atomic A join dest D.
         WF R
          /\ (!x. ~(atomic x) ==> R (dest x) x)
          ==>
           !x f. cpRec A D atomic dest join (x, f)
                   =
                 f (linRec A D atomic dest join x)`,
      REPEAT GEN_TAC THEN STRIP_TAC
        THEN recInduct (theorem "linRec_ind")
        THEN RW_TAC std_ss []
        THEN unfold_once
        THEN RW_TAC std_ss [])
  end;


(* When join is associative, accRec with accumulator u agrees with cpRec
   run with the continuation that joins its argument onto u.  Proved by
   the induction principle of accRec, unfolding each recursion once. *)
val accRec_eq_cpRec =
  let
    (* one unfolding step of cpRec followed by one of accRec *)
    val unfold_once =
      ONCE_REWRITE_TAC [cpRec_eqn] THEN ONCE_REWRITE_TAC [accRec_eqn]
  in
    Q.prove
     (`!R atomic A join dest D.
         WF R
         /\ (!x. ~atomic x ==> R (dest x) x)
         /\ (!p q r:'b. join p (join q r) = join (join p q) r)
          ==>
           !x u. accRec A D atomic dest join (x,u)
                   =
                 cpRec A D atomic dest join (x, \w. join w u)`,
      REPEAT GEN_TAC THEN STRIP_TAC
        THEN recInduct accRec_ind
        THEN RW_TAC std_ss []
        THEN unfold_once
        THEN RW_TAC std_ss [])
  end;


(*---------------------------------------------------------------------------
     Ergo, we have the following equality between linear and 
     accumulator recursions ... this can also be proved directly
     via an equally easy proof.
 ---------------------------------------------------------------------------*)

(* Joining the result of linRec with u equals running accRec with
   accumulator u, provided join is associative.  Obtained by chaining the
   two previous theorems: accRec is rewritten to cpRec, and cpRec to
   linRec. *)
val linRec_eq_accRec =
  let
    (* simplify the goal with a single derived theorem *)
    fun rw_with th = RW_TAC std_ss [th]
  in
    Q.prove
     (`!R atomic A join dest D.
         WF R
         /\ (!x. ~(atomic x) ==> R (dest x) x)
         /\ (!p q r:'b. join p (join q r) = join (join p q) r)
          ==>
           !x u. join (linRec A D atomic dest join x) u
                   =
                 accRec A D atomic dest join (x,u)`,
      (* weakness in solver forces use of IMP_RES_THEN *)
      REPEAT STRIP_TAC
        THEN IMP_RES_THEN rw_with accRec_eq_cpRec
        THEN IMP_RES_THEN rw_with cpRec_eq_linRec)
  end;


(*---------------------------------------------------------------------------
              Two versions of reverse.
 ---------------------------------------------------------------------------*)

(* List reversal as an instance of the linear recursion schema.  Unfolding
   linRec with these arguments gives
     rev l = if NULL l then l else APPEND (rev (TL l)) [HD l]
   The inline comments name the schema parameter each argument fills. *)
val rev_def = 
 Define
    `rev = linRec (*  A  *)    (\x.x)       
                  (*  D  *)    (\l. [HD l]) 
                  (* atomic *) NULL
                  (*  dest  *) TL           
                  (*  join  *) APPEND`;


(* Reversal with an accumulator, as an instance of the accumulator schema
   with the same five arguments as rev.  Unfolding accRec gives
     frev l a = if NULL l then APPEND l a
                else frev (TL l) (APPEND [HD l] a)
   Here l is the list still to be processed and a is the accumulator. *)
val frev_def = 
 Define
    `frev l a = accRec (*  A  *)    (\x.x)
                       (*  D  *)    (\l. [HD l])
                       (* atomic *)  NULL
                       (*  dest  *)  TL
                       (*  join  *)  APPEND
                       (l,a)`;


(*---------------------------------------------------------------------------
      Equivalence of the two forms of reverse. First, instantiate
      the program transformation "sufficiently".
 ---------------------------------------------------------------------------*)

(* linRec_eq_accRec with R instantiated to measure LENGTH, the equation
   turned around so that it rewrites accRec into join-of-linRec, and all
   free variables generalised again. *)
val lem0 =
  let
    val at_length = Q.ISPEC `measure LENGTH` linRec_eq_accRec
    (* strip quantifiers and antecedents, then flip the equation *)
    val flipped   = GSYM (UNDISCH_ALL (SPEC_ALL at_length))
  in
    GEN_ALL (DISCH_ALL flipped)
  end;

(* The tail of a non-empty list is smaller than the list itself under
   measure LENGTH.  Proved by case analysis on the list after unfolding
   measure and inv_image. *)
val lem1 =
  let
    val measure_rws =
      [prim_recTheory.measure_def, relationTheory.inv_image_def]
  in
    Q.prove
     (`!x. ~NULL x ==> measure LENGTH (TL x) x`,
      Cases THEN RW_TAC list_ss measure_rws)
  end;


(* The linear and the accumulator versions of reverse agree when the
   accumulator starts out empty.  Both definitions are unfolded and the
   instantiated transformation lem0 is applied; lem1 and WF_measure are
   supplied to the simplifier for the side conditions of lem0.
   The goal quantifies over l only: the former statement also bound a
   variable x that did not occur in the body, which left a vacuous
   quantifier over an unconstrained type in the theorem. *)
val rev_eq_frev = Q.prove
(`!l. rev l = frev l []`,
RW_TAC list_ss [rev_def, frev_def, lem0, lem1, prim_recTheory.WF_measure]);


(*---------------------------------------------------------------------------
    The natural way to apply these rewrites seems to be to let the
    user write something in a simple style and then let the
    system do some wily higher-order matching in order to choose
    and apply a transformation. Even after the transformation has 
    been applied, the new recursion equations may have to be simplified
    in order to be intelligible to the user.
 ---------------------------------------------------------------------------*)
