(*---------------------------------------------------------------------------*
 * Setting information for types.                                            *
 *                                                                           *
 * Here we fill in the "boolify" entries in theTypeBase for some types       *
 * built before lists: bool,prod,sum,num,option, ...                         *
 *---------------------------------------------------------------------------*)

structure Boolify =
struct

local
  open HolKernel boolLib Parse pairSyntax numSyntax listSyntax combinSyntax
       arithmeticTheory
  infix |->; infixr -->;

(* Exception builder for this structure: mk_HOL_ERR partially applied to the
   structure name, leaving the function name and message to be supplied.
   Within this file it is referenced only by the commented-out type_boolify
   sketch further down. *)
val ERR = mk_HOL_ERR "Boolify";

(*---------------------------------------------------------------------------
     Constructor encodings 
 ---------------------------------------------------------------------------*)

(* Encodings are bit strings, i.e. HOL terms of type :bool list. *)
val bit_list = listSyntax.mk_list_type bool;

(* One tag prefix per constructor.  Each tag is the HOL list literal made
   from the given boolean terms. *)
local
  fun bits bs = listSyntax.mk_list (bs, bool)
in
  (* booleans *)
  val encT     = bits [T]
  val encF     = bits [F]
  (* options *)
  val encNone  = bits [F]
  val encSome  = bits [T]
  (* numbers *)
  val encB2    = bits [F]
  val encB1    = bits [T, F]
  val encZ     = bits [T, T]
  (* lists *)
  val encNil   = bits [F]
  val encCons  = bits [T]
  (* sums *)
  val encInl   = bits [T]
  val encInr   = bits [F]
end


(*---------------------------------------------------------------------------
        Booleans
 ---------------------------------------------------------------------------*)

  (* Boolify entry for :bool.  The encoder is the bool case constant applied
     to the two tag lists: the full case term is built over a dummy ARB
     scrutinee and rator then strips that last argument off again. *)
  val bool_info = Option.valOf (TypeBase.read "bool");
  val bool_boolify_info =
    let val applied_case = mk_bool_case (encT, encF, mk_arb bool)
    in (rator applied_case, TypeBasePure.ORIG boolTheory.bool_case_thm)
    end;
  val bool_info' = TypeBasePure.put_boolify bool_boolify_info bool_info;

(*---------------------------------------------------------------------------
        Pairs
 ---------------------------------------------------------------------------*)

  (* Boolify entry for pairs: \f g (x,y). APPEND (f x) (g y), where f and g
     encode the two components.  No tag bit is emitted. *)
  val prod_boolify_info =
    let val enc_fst  = mk_var("f", alpha --> bit_list)
        val enc_snd  = mk_var("g", beta --> bit_list)
        val left     = mk_var("x", alpha)
        val right    = mk_var("y", beta)
        val joined   = listSyntax.mk_append (mk_comb(enc_fst,left),
                                             mk_comb(enc_snd,right))
        val on_pair  = mk_pabs (mk_pair(left,right), joined)
    in (list_mk_abs ([enc_fst,enc_snd], on_pair),
        TypeBasePure.ORIG pairTheory.UNCURRY_DEF)
    end

  val prod_info' =
    let val current = Option.valOf (TypeBase.read"prod")
    in TypeBasePure.put_boolify prod_boolify_info current
    end

(*---------------------------------------------------------------------------
        Sums
 ---------------------------------------------------------------------------*)

(* CONS specialised to booleans; also used by the option encoder below. *)
val cons_bool = inst [alpha |-> bool] listSyntax.cons_tm;

  (* Boolify entry for sums:
       \f g s. sum_case (CONS T o f) (CONS F o g) s
     so each injection is tagged with one bit ahead of its payload's
     encoding. *)
  val sum_info = Option.valOf(TypeBase.read "sum");
  val sum_boolify_info =
    let val case_const = inst [alpha |-> bit_list]
                              (prim_mk_const{Name="sum_case", Thy="sum"})
        val enc_left  = mk_var("f",beta --> bit_list)
        val enc_right = mk_var("g",gamma --> bit_list)
        val scrutinee =
          mk_var("s",mk_thy_type{Tyop="sum",Thy="sum",Args=[beta, gamma]})
        val tag_left  = mk_o(mk_comb(cons_bool,T),enc_left)
        val tag_right = mk_o(mk_comb(cons_bool,F),enc_right)
        val encoder =
          list_mk_abs ([enc_left,enc_right,scrutinee],
                       list_mk_comb(case_const,[tag_left,tag_right,scrutinee]))
    in (encoder, TypeBasePure.ORIG sumTheory.sum_case_def)
    end
  val sum_info' = TypeBasePure.put_boolify sum_boolify_info sum_info

(*---------------------------------------------------------------------------
        Options
 ---------------------------------------------------------------------------*)

  (* Boolify entry for options: \f. option_case [F] (CONS T o f).
     The literal [F] is the code for the no-value case; a present value is
     tagged with T ahead of its own encoding. *)
  val option_info = Option.valOf (TypeBase.read "option")
  val option_case_tm = prim_mk_const{Name="option_case",Thy="option"}
  val option_boolify_info =
    let val enc        = mk_var("f",alpha --> bit_list)
        val none_code  = listSyntax.mk_list([F],bool)
        val some_code  = mk_o(mk_comb(cons_bool,T),enc)
        val case_const = inst [beta |-> bit_list] option_case_tm
    in (mk_abs(enc, list_mk_comb(case_const,[none_code,some_code])),
        TypeBasePure.ORIG optionTheory.option_case_def)
    end
  val option_info' = TypeBasePure.put_boolify option_boolify_info option_info


(*---------------------------------------------------------------------------
        Lists
 ---------------------------------------------------------------------------*)

(* listToBool f l encodes a list given an element encoder f: the empty list
   is [F]; a cons cell is T followed by the head's encoding and then the
   encoding of the tail. *)
val listToBool = Define
     `(listToBool f [] = [F])
   /\ (listToBool f (h::t) = T::APPEND (f h) (listToBool f t))`;


(* Install listToBool as the boolify function for the list type. *)
val list_info = Option.valOf (TypeBase.read "list")
val list_boolify_info =
  let val encoder_ty =
        (alpha --> bit_list) --> listSyntax.mk_list_type alpha --> bit_list
  in (mk_const("listToBool", encoder_ty), TypeBasePure.ORIG listToBool)
  end
val list_info' = TypeBasePure.put_boolify list_boolify_info list_info;


(*---------------------------------------------------------------------------
        Nums (Norrish numeral encoding)
 ---------------------------------------------------------------------------*)

(* EncodeNum : num -> bool list, defined by well-founded recursion:
     0               is encoded as  [T;T]
     n even, n <> 0  is encoded as  F :: EncodeNum ((n-2) DIV 2)
     n odd           is encoded as  T :: F :: EncodeNum ((n-1) DIV 2)
   Hol_defn sets the definition up and Defn.tprove discharges its termination
   conditions with the tactic below, using $< on num as the well-founded
   relation.  The value bound here is a pair; its first component is what is
   later installed as the boolify theorem for num. *)
val EncodeNum = 
 Defn.tprove
  (Hol_defn "EncodeNum"
    `EncodeNum (n:num) = 
         if n=0 then [T;T]
                else if EVEN n then F::EncodeNum ((n-2) DIV 2)
                               else T::F::EncodeNum ((n-1) DIV 2)`,
   WF_REL_TAC `$<`
    THEN RW_TAC std_ss []
    THEN `?j. n = SUC j` by PROVE_TAC [num_CASES]
    THEN IMP_RES_TAC EVEN_EXISTS
    THEN RW_TAC arith_ss [SUC_SUB1,MULT_DIV,DIV_LESS_EQ,
                          DECIDE (Term `2n*m - 2n = (m-1n)*2n`),
                          DECIDE (Term `x < SUC y = x <= y`)]);

  (*--------------------------------------------------------------------
       Termination proof can also go: 

           WF_REL_TAC `$<` THEN intLib.COOPER_TAC

       but then we'd need integers.
   ----------------------------------------------------------------------*)


  (* Install EncodeNum as the boolify function for num.  Only the first
     component of the EncodeNum pair is recorded as the supporting theorem. *)
  val num_info = Option.valOf (TypeBase.read "num")
  val num_boolify_info =
    let val (encode_eqns, _) = EncodeNum
        val encode_tm = mk_const("EncodeNum",numSyntax.num --> bit_list)
    in (encode_tm, TypeBasePure.ORIG encode_eqns)
    end
  val num_info' = TypeBasePure.put_boolify num_boolify_info num_info;

(*---------------------------------------------------------------------------
      The unit type is cool because it consumes no space in the
      target list: the type has all the information! (Thanks to Joe
      Hurd for pointing this out.) 
 ---------------------------------------------------------------------------*)

  (* Encoder for the one-element type "one": \v. [], i.e. the value
     contributes no bits at all.
     mk_nil takes the ELEMENT type of the list, so the empty bit string is
     mk_nil bool.  The previous mk_nil bit_list built [] : bool list list,
     which is not a bit string.
     This entry is only constructed here; this file does not install it in
     the TypeBase.
     NOTE(review): REFL_CLAUSE looks like a placeholder for the supporting
     theorem; confirm before installing this entry. *)
  val unit_boolify_info =
       (mk_abs(mk_var("v",mk_thy_type{Tyop="one", Thy="one",Args=[]}),
               mk_nil bool),
        TypeBasePure.ORIG boolTheory.REFL_CLAUSE);

in
   (* Install the updated entries in the global TypeBase, in this order:
      bool, prod, option, list, sum, num. *)
   val _ = List.map TypeBase.write
             [bool_info', prod_info', option_info',
              list_info', sum_info', num_info']
end

(*
(*---------------------------------------------------------------------------
    Map a HOL type (ty) into a term having type :ty -> bool list.
 ---------------------------------------------------------------------------*)

local fun tyboolify_env db = 
            Option.map fst o 
            Option.composePartial (TypeBasePure.boolify_of,
                                   TypeBasePure.get db)
      fun undef _ = raise ERR "type_boolify" "unknown type"
      fun theta ty = if is_vartype ty 
                        then raise ERR "type_boolify" "type variable"
                        else NONE
in
fun type_boolify db = TypeBasePure.typeValue (theta,tyboolify_env db,undef)
end

fun boolify tm = 
  let val db = TypeBase.theTypeBase()
      val f = type_boolify db (type_of tm)
  in EVAL (mk_comb(f,tm))
  end;

CONV_RULE (REDEPTH_CONV pairLib.PAIRED_BETA_CONV)
          (boolify (Term `[(1,2,3,4) ; (5,6,7,8)]`));
boolify (Term `(SOME F, 123, [F;T])`);
boolify (Term `[]:num list`);

Tasks: 

* prefix-free code generation
* define_boolify in Datatype
* 

*)


(*---------------------------------------------------------------------------
       Decoding
 ---------------------------------------------------------------------------*)

(* DecodeBool bits = (value, remaining bits).
   Consumes the leading bit and returns it as the decoded boolean together
   with the rest of the list.  Any other input (in particular the empty
   list) yields (ARB,[]). *)
val DecodeBool =
 Define
    `(DecodeBool (T::t) = (T,t))
 /\  (DecodeBool (F::t) = (F,t))
 /\  (DecodeBool els    = (ARB,[]))`;


(* DecodeOption f bits = (value, remaining bits), where f decodes the payload.
   A leading F decodes to NONE.  A leading T runs f on the rest and wraps
   its result in SOME, passing on the bits f left over.  Any other input
   yields (ARB,[]). *)
val DecodeOption =
 Define
    `(DecodeOption f (F::t) = (NONE,t))
 /\  (DecodeOption f (T::t) = let (v,t') = f t in (SOME v, t'))
 /\  (DecodeOption f els    = (ARB,[]))`;


(* DecodeSum f g bits = (value, remaining bits).
   A leading T selects the left injection: f decodes the payload and the
   result is wrapped in INL.  A leading F selects the right injection via g
   and INR.  Any other input yields (ARB,[]). *)
val DecodeSum =
 Define
    `(DecodeSum f g (T::t) = let (u:'a, t':bool list) = f t in (INL u, t'))
 /\  (DecodeSum f g (F::t) = let (v:'b, t') = g t in (INR v, t'))
 /\  (DecodeSum f g els    = (ARB,[]))`;


(* DecodePair f g bits = ((x,y), remaining bits).
   Runs f on the input to obtain the first component, then runs g on the
   bits f left over to obtain the second.  No tag bit is consumed. *)
val DecodePair =
 Define
    `DecodePair (f:bool list -> 'a#bool list) 
                (g:bool list -> 'b#bool list)
                (els:bool list) 
      = let (x,t1) = f els in 
        let (y,t2) = g t1
        in ((x,y),t2)`;

(* DecodeNum bits = (number, remaining bits).
     T::F::rst  decodes rst to v and returns 2*v + 1
     F::rst     decodes rst to v and returns 2*v + 2
     T::T::rst  returns 0
   Any other input yields (ARB,[]).  The induction theorem produced for this
   definition is fetched below under the name DecodeNum_ind. *)
val DecodeNum =
 Define
    `(DecodeNum (T::F::rst) = let (v,rst') = DecodeNum rst in (2*v + 1, rst'))
 /\  (DecodeNum (F::rst)    = let (v,rst') = DecodeNum rst in (2*v + 2, rst'))
 /\  (DecodeNum (T::T::rst) = (0n,rst))
 /\  (DecodeNum els         = (ARB,[]))`;


(* DecodeList bits = (list, remaining bits), as a schematic definition.
   A leading F decodes to the empty list.  A leading T decodes one element
   with f, then recursively decodes the tail from the bits f left over.
   Any other input yields (ARB,[]).
   f is not a parameter of the equations: it occurs free, which is why
   DefineSchema is used.  The resulting theorem is specialised afterwards
   (see DecodeList1) by instantiating f and a relation variable R. *)
val DecodeList =
 TotalDefn.DefineSchema
    `(DecodeList (F::t) = ([],t))
 /\  (DecodeList (T::t) = 
        let (h,t1) = f t in 
        let (tail,t2) = DecodeList t1 
        in (h::tail, t2))
 /\  (DecodeList els = (ARB,[]))`;


(* DecodeList specialised to lists of numbers: the element type becomes num,
   the schematic decoder f becomes DecodeNum and the relation R becomes
   measure LENGTH.  The hypotheses are then discharged into the conclusion
   and WF_measure is used to rewrite the result. *)
val DecodeList1 =
  let val at_num = INST_TYPE [alpha |-> num] DecodeList
      val choices =
        [Term `R:bool list -> bool list -> bool` |-> 
         Term `measure (LENGTH:bool list -> num)`,
         Term `f:bool list -> num # bool list`  |-> Term `DecodeNum`]
      val specialised = INST choices at_num
  in REWRITE_RULE [prim_recTheory.WF_measure] (DISCH_ALL specialised)
  end;

(* Induction theorem stored for DecodeNum in the current theory segment. *)
val ind_thm = fetch "-" "DecodeNum_ind";

open pairTheory pairTools;

(* The antecedent of DecodeList1 (the left-hand side of its top-level
   implication), proved by induction along DecodeNum's recursion. *)
val lemma = prove
(fst(dest_imp(concl DecodeList1)),
 recInduct ind_thm
  THEN RW_TAC list_ss [DecodeNum,prim_recTheory.measure_def,
                       relationTheory.inv_image_def]
  THEN Q.PAT_ASSUM `x = y` MP_TAC 
  THEN LET_EQ_TAC [LET2_RATOR,LET2_RAND]
  THEN RES_TAC 
  THEN RW_TAC list_ss []);

(* Modus ponens removes the antecedent, leaving the unconditional equations
   for DecodeList over DecodeNum. *)
val DecodeNumList = MP DecodeList1 lemma;

(* Register those equations with computeLib. *)
val _ = computeLib.add_funs [DecodeNumList];

(* 
EVAL (Term `DecodeList DecodeNum ^(rhs(concl(boolify (Term`[1;2;3]`))))`);
EVAL (Term `DecodePair (DecodeList DecodeNum) 
                       (DecodeOption DecodeBool)
              ^(rhs(concl(boolify (Term`([1;2;3],SOME T)`))))`);
*)


(*---------------------------------------------------------------------------
  What's it look like for a nested datatype?
 ---------------------------------------------------------------------------*)

(* Nested datatype used for the experiments below: a node carries a label
   and a list of subtrees.  The call is bound with val _ because a bare
   expression is not a legal declaration inside a structure body. *)
val _ = Hol_datatype `tree = Node of 'a => tree list`;

(*---------------------------------------------------------------------------
   Alternate induction scheme.

     !P. (!l. (!x. MEM x l ==> P x) ==> !h. P (Node h l)) ==> !t. P t

   Size of a tree

     tree_size f (Node x tlist) = 1 + f x + list_size (tree_size f) tlist
 ---------------------------------------------------------------------------*)



(*---------------------------------------------------------------------------
    Need to install second order congruence theorems for TC extraction.
    Both listToBool and DecodeList need these.
 ---------------------------------------------------------------------------*)

(* Congruence rule for listToBool: the element encoder only needs to agree
   on members of the list.  Proved by induction on the first list, then
   simplifying twice with the defining equations. *)
val listToBool_cong =
  let val unfold = RW_TAC list_ss [listToBool]
  in Q.prove
      (`!M N f f'.
      (M = N) /\ 
      (!x. MEM x N ==> (f x = f' x)) 
      ==>
      (listToBool f M = listToBool f' N)`,
       Induct THEN unfold THEN unfold)
  end;

(* Put the new rule at the front of the congruence rules known to DefnBase. *)
val _ =
  let val known = DefnBase.read_congs()
  in DefnBase.write_congs (listToBool_cong :: known)
  end;

(* Encoder for the nested tree type: the label's encoding followed by the
   listToBool encoding of the subtrees, with EncodeTree f as the element
   encoder.  Hol_defn only sets the definition up; termination is proved
   separately below. *)
val EncodeTree_defn = 
  Hol_defn 
     "EncodeTree"
     `EncodeTree f (Node x tlist) = 
          APPEND (f x) (listToBool (EncodeTree f) tlist)`;

(* Termination proof.  The relation handed to WF_REL_TAC is the second
   candidate returned by TotalDefn.guessR for this definition, spliced into
   the quotation by antiquotation.  After Induct and simplification with
   tree_size_def, the first subgoal is left as it is, the second has its
   resolved assumptions moved into the goal, and DECIDE_TAC finishes both. *)
val EncodeTree = Defn.tprove
 (EncodeTree_defn,
  WF_REL_TAC `^(hd(tl (TotalDefn.guessR EncodeTree_defn)))`
    THEN Induct THEN RW_TAC list_ss [fetch "-" "tree_size_def"]
    THENL [ALL_TAC, RES_THEN MP_TAC] 
    THEN DECIDE_TAC);

(*---------------------------------------------------------------------------
   Unfinished sketch, kept for reference.  None of it is evaluated: it mixes
   notes with interactive goal-stack commands (set_goal, e, expandf) that
   cannot appear inside a structure body, and it previously sat outside any
   comment.

   ????????????? Coinductive problems :

                  what is the congruence rule for DecodeList

   DecodeList f returns (alist,rst)
   ???????????????????

val DecodeList_ind = fetch "-" "DecodeList_ind";

val DecodeList_cong = 
set_goal(hyp DecodeList_ind,
Term`!M N g g'.
      (M = N) /\ 
      (!x. LENGTH x < LENGTH N ==> (g x = g' x))
      ==>
      (DecodeList g M = DecodeList g' N)`);

e (recInduct DecodeList_ind);
expandf (RW_TAC list_ss [DecodeList]);
        THEN RW_TAC list_ss [DecodeList]);

val _ = DefnBase.write_congs(DecodeList_cong::DefnBase.read_congs());


val DecodeTree_defn = 
TotalDefn.DefineSchema
   `DecodeTree els = 
      let (x,els_1) = f els in
      let (terml,els2) = DecodeList DecodeTree els_1
      in (Node x terml, els2)`;


   Might need to instantiate f before anything can even be defined ... Nah ...
 ---------------------------------------------------------------------------*)

end
