(*---------------------------------------------------------------------------
    Transformation of programs involving "unfold" : the fusion law. 
 ---------------------------------------------------------------------------*)

app load ["bossLib", "Q", "pairTools"]; open bossLib pairTools; 

(* Session setup.
     - `by` is declared infix at precedence 8 and `&&` is declared infix
       with no explicit precedence.  Neither operator is used in the part
       of the script visible here.
     - show_assums is set to true so that the hypotheses of theorems are
       shown when they are printed; the development below works with
       theorems that carry hypotheses (see unfold_ind). *)
infix 8 by; infix &&;
show_assums := true;


(*---------------------------------------------------------------------------
     Datatype of binary trees with data at each internal node.
 ---------------------------------------------------------------------------*)

(* btree is polymorphic in the type 'a of the datum stored at a node:
     LEAF          -- constructor with no arguments
     NODE t1 a t2  -- two btree arguments with a value of type 'a
                      between them *)
Hol_datatype `btree 
                = LEAF 
                | NODE of btree => 'a => btree`;


(*---------------------------------------------------------------------------
    Standard primitive recursor over btree.
 ---------------------------------------------------------------------------*)

(* btreeRec t v f reduces the tree t to a single value:
     - on LEAF the result is v;
     - on NODE t1 M t2 the result is f applied to the reduction of t1,
       the node datum M, and the reduction of t2 (in that order).
   The annotations in the first clause fix the types: v and the result
   have type 'a, while the node datum passed to f has type 'b. *)
val btreeRec_def = 
 Define
   `(btreeRec LEAF (v:'a) (f:'a->'b->'a->'a)  = v)
 /\ (btreeRec (NODE t1 M t2) v f = f (btreeRec t1 v f) M (btreeRec t2 v f))`;



(*---------------------------------------------------------------------------
     "unfold" into a btree. The following is not correct, because "more" 
     and "dest" are not arguments to the function (i.e. things that
     are recursed on), but rather parameters that have to be filled in
     before the definition makes sense.

            unfold more dest x = 
              if more x 
               then let (y,a,z) = dest x
                    in 
                     NODE (unfold more dest y) a (unfold more dest z)
               else LEAF

     Following is the right way to define the unfold schema. Since dest 
     and more are only free in the right hand side, they are treated
     as parameters.
 ---------------------------------------------------------------------------*)

(* Schematic definition of unfold.  Only the seed x is a formal argument;
   `more` and `dest` occur free on the right-hand side:
     - if `more x` holds, `dest x` is split into a triple (y1,b,y2) and a
       NODE is built from the unfolding of y1, the datum b, and the
       unfolding of y2;
     - otherwise the result is LEAF.
   No termination relation is supplied here.  The statement of fusion_law
   later in this file applies the constant as `unfold dest more x`, i.e.
   with dest and more passed before the seed. *)
val unfold_def = 
 Define
     `unfold (x:'a) = 
        if more x 
          then let (y1,b,y2) = dest x 
               in 
                  NODE (unfold y1) b (unfold y2)
          else LEAF`;



(*---------------------------------------------------------------------------
    val unfold_ind 
      = 
    [WF R, !x y1 b y2. more x /\ ((y1,b,y2) = dest x) ==> R y1 x,
           !x y1 b y2. more x /\ ((y1,b,y2) = dest x) ==> R y2 x]
    |- !P.
         (!x.
            (!y1 b y2. more x /\ ((y1,b,y2) = dest x) ==> P y1) /\
            (!y1 b y2. more x /\ ((y1,b,y2) = dest x) ==> P y2) ==>
            P x) ==>
         !v. P v
 ---------------------------------------------------------------------------*)

(* Look up the theorem stored under the name "unfold_ind" in the current
   theory -- presumably the induction principle generated by the Define
   of unfold above; its expected shape is the one displayed in the
   preceding comment. *)
val unfold_ind = theorem "unfold_ind";


(*---------------------------------------------------------------------------
        "fusion" is just a generalization of unfold.
 ---------------------------------------------------------------------------*)

(* Schematic definition of fusion.  It has the same recursion pattern as
   unfold, but where unfold builds NODE it applies g, and where unfold
   returns LEAF it returns c.  Only the seed x is a formal argument;
   `more`, `dest`, `g` and `c` occur free on the right-hand side.
   The annotations fix i : 'b (the middle component of dest x) and
   c : 'c (the result type).  The statement of fusion_law later in this
   file applies the constant as `fusion c dest g more x`. *)
val fusion_def = 
 Define
     `fusion (x:'a) = 
         if more x 
           then let (y,i,z) = dest x
                in 
                   g (fusion y) (i:'b) (fusion z)
           else (c:'c)`;


(*---------------------------------------------------------------------------
     Prove that unfolding and then reducing is the same as doing 
     a fusion. We want to lift "let" bindings from the goal and 
     put them on the assumptions. To do that, we need to supply
     the following rewrite rules that lift a let that binds
     3 values. 
 ---------------------------------------------------------------------------*)

(* Rewrite rules that move a surrounding context across a "let" binding a
   triple.  Both goals are discharged by the very same script, so it is
   written once and kept private to the two lemmas: expand "let" with
   LET_DEF, beta-reduce the resulting paired abstractions at every depth
   with GEN_BETA_CONV, then finish with the basic rewrites. *)
local
  val let3_script =
    REWRITE_TAC [boolTheory.LET_DEF]
      THEN CONV_TAC (DEPTH_CONV pairTools.GEN_BETA_CONV)
      THEN REWRITE_TAC []
in

(* A function P applied to a triple-let equals the let with P applied to
   its body. *)
val LET3_RAND = Q.prove
(`!(P:'d->'e) (M:'a#'b#'c) N.
       P (let (x,y,z) = M in N x y z) = (let (x,y,z) = M in P (N x y z))`,
 let3_script);

(* A triple-let applied to an argument b equals the let whose body is
   applied to b. *)
val LET3_RATOR = Q.prove
(`!(M:'a1#'a2#'a3) (N:'a1->'a2->'a3->'b->'c) (b:'b). 
      (let (x,y,z) = M in N x y z) b = let (x,y,z) = M in N x y z b`,
 let3_script);

end;


(* Fusion law: provided R is wellfounded and both components y1 and y2
   produced by `dest x` are R-below x whenever `more x` holds, reducing
   the unfolded tree with btreeRec (using c for LEAF and g for NODE)
   gives the same result as running fusion directly on the seed.
   The two descent hypotheses are stated here in the opposite order
   (y2 first, then y1) to the one shown for unfold_ind above. *)
val fusion_law = Q.prove
(`!R dest more.
     WF R
     /\ (!x y1 b y2. more x /\ ((y1,b,y2) = dest x) ==> R y2 x) 
     /\ (!x y1 b y2. more x /\ ((y1,b,y2) = dest x) ==> R y1 x) 
     ==>
        !x c g. 
           btreeRec (unfold dest more x) c g = fusion c dest g more x`,
(* Strip the outer quantifiers and move the three hypotheses into the
   assumption list. *)
REPEAT GEN_TAC THEN STRIP_TAC 
  (* Induct with the induction principle for unfold, then simplify. *)
  THEN recInduct unfold_ind THEN RW_TAC std_ss []
  (* Expand unfold and fusion by exactly one step each. *)
  THEN ONCE_REWRITE_TAC[unfold_def] THEN ONCE_REWRITE_TAC[fusion_def] 
  (* Lift the triple-binding lets out of the goal, using the two LET3
     rewrite rules proved above. *)
  THEN LET_EQ_TAC [LET3_RATOR,LET3_RAND]
  (* Take an assumption matching the equation pattern, reverse it, and
     substitute with it throughout -- presumably the (y1,b,y2) = dest x
     binding introduced by the previous step. *)
  THEN Q.PAT_ASSUM `x = y` (SUBST_ALL_TAC o SYM) 
  (* Simplify with the btreeRec equations; PROVE_TAC closes what remains. *)
  THEN RW_TAC std_ss [btreeRec_def]
  THEN PROVE_TAC []);
