(*---------------------------------------------------------------------------*
 * Define a HOL datatype and derive a collection of theorems about it. The   *
 * derived facts are packaged as a Facts.tyinfo value: primHol_datatype      *
 * returns it, and Hol_datatype records it in this module's fact base.       *
 *                                                                           *
 * Examples of use:                                                          *
 *                                                                           *
 *   local open Datatype                                                     *
 *   in                                                                      *
 *   val term_ty_def =                                                       *
 *       Hol_datatype `term = Var of 'a                                      *
 *                          | Const of 'a                                    *
 *                          | App of term => term`                           *
 *                                                                           *
 *   val subst_ty_def =                                                      *
 *       Hol_datatype  `subst = Fail | Subst of ('a#'a term) list`           *
 *   end;                                                                    *
 *                                                                           *
 *                                                                           *
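 * The recorded information includes the datatype axiom, the case-constant   *
 * definition, one-one and distinctness theorems (when they can be proved),  *
 * the definition of a size function, and further facts that Facts.mk_facts  *
 * derives from them (such as an nchotomy theorem).                          *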
 *                                                                           *
 * Side-effects: it adds the definition of the type, the definition of the   *
 * case constant, and the definition of the size function to the current     *
 * theory, and records the derived facts in this module's fact base.         *
 *                                                                           *
 * Notice that, at least using the current mechanism for defining types,     *
 * a great many theories get loaded in: numbers, lists, trees, etc. when     *
 * this module is loaded. If your formalization doesn't want to have these   *
 * as parents, then Facts.mk_facts should be used directly.                  *
 *---------------------------------------------------------------------------*)

structure Datatype :> Datatype =
struct

open HolKernel Drule Tactical Tactic Conv Prim_rec ConstrProofs;

infix ## |-> THEN THENL;
infixr -->;

type term = Term.term
type tyinfo = Facts.tyinfo
type hol_type = Type.hol_type
type factBase = Facts.factBase;

(* NOTE(review): the type parameter 'a is not used on the right-hand side.
   The definition is kept unchanged, presumably because the Datatype
   signature expects exactly this shape -- confirm before altering. *)
type 'a quotation = term frag list;


(* ERR func mesg: build (but do not raise) a HOL_ERR value attributed to
   structure "Datatype" and to the function named func, carrying mesg. *)
fun ERR func mesg = 
    HOL_ERR{origin_structure = "Datatype",
            origin_function = func,
            message = mesg};

(* defn_const th: the constant defined by the definition theorem th.
   The conclusion of th is split into conjuncts. The first conjunct has its
   universal quantifiers stripped, and the head of its left-hand side is
   returned. *)
val defn_const = 
  #1 o strip_comb o lhs o #2 o strip_forall o hd o strip_conj o concl;

(* join f g x: chain two partial lookups.
   - If g x is NONE, the result is NONE.
   - If g x is SOME y and f y is SOME(z,_), the result is SOME z.
   - Otherwise (f y is NONE) the result is NONE. *)
fun join f g x =
  case (g x) 
   of NONE => NONE
    | SOME y => (case (f y)
                  of NONE => NONE
                   | SOME(x,_) => SOME x);

(* num_variant vlist v: a variable with the type of v whose name clashes
   with no variable in vlist. The name of v is tried first. After that, the
   name of v combined (via Lib.concat) with 1, 2, 3, ... is tried until an
   unused name is found. *)
fun num_variant vlist v =
  let val counter = ref 0
      val {Name,Ty} = dest_var v
      val slist = ref (map (#Name o dest_var) vlist)
      fun pass str = 
         if (mem str (!slist)) 
         then ( counter := !counter + 1;
                pass (Lib.concat Name (Lib.int_to_string(!counter))))
         else (slist := str :: !slist; str)
  in 
  mk_var{Name=pass Name,  Ty=Ty}
  end;


(*---------------------------------------------------------------------------*
 * Define a case constant for a datatype. Case constants are used by TFL's   *
 * pattern-matching translation and are generally useful as replacements     *
 * for "destructor" operations.                                              *
 *---------------------------------------------------------------------------*)

(* define_case ax: ax is a recursion axiom whose body, under its outer
   universal quantifiers, has the form
       ?!fn. (fn (C1 ...) = ...) /\ ... /\ (fn (Cn ...) = ...).
   This function defines <Tyop>_case with new_recursive_definition and
   stores the definition under the name <Tyop>_case_def.
   - The constant takes one function argument per constructor.
   - Each argument is named "v" for a nullary constructor and "f" otherwise.
     The names are made distinct from each other and from the variables in
     the constructor patterns.
   - The clauses have the form
         <Tyop>_case a1 ... an (Ci ...) = ai applied to the variables of the
         Ci pattern, in the order rev(free_vars (Ci ...)).
   Raises HOL_ERR if ax does not have the expected ?! form. *)
fun define_case ax =
   let val exu = snd(strip_forall(concl ax))
       val {Rator,Rand} = dest_comb exu
       (* Check explicitly for ?!. A refutable record pattern here would
          raise Bind, which escapes the HOL_ERR handlers of callers. *)
       val _ = if #Name(dest_const Rator) = "?!" then ()
               else raise ERR "define_case" "axiom is not a ?! statement"
       val {Bvar,Body} = dest_abs Rand
       val ty = type_of Bvar
       val (dty,rty) = Type.dom_rng ty
       val {Tyop,...} = dest_type dty
       (* The constructor applications: the argument of fn on each lhs. *)
       val clist = map (rand o lhs o #2 o strip_forall) (strip_conj Body)
       (* One fresh function variable per constructor. It takes the
          constructor's argument types and returns the result type rty. *)
       fun mk_cfun ctm (nv,away) =
          let val (_,args) = strip_comb ctm
              val fty = itlist (curry (op -->)) (map type_of args) rty
              val vname = if (length args = 0) then "v" else "f"
              val v = num_variant away (mk_var{Name = vname, Ty = fty})
          in (v::nv, v::away)
          end
      val arg_list = rev(fst(rev_itlist mk_cfun clist ([],free_varsl clist)))
      val v = mk_var{Name = Tyop^"_case",
                     Ty = itlist (curry (op -->)) (map type_of arg_list) ty}
      val preamble = list_mk_comb(v,arg_list)
      fun clause (a,c) = mk_eq{lhs = mk_comb{Rator=preamble,Rand=c},
                               rhs = list_mk_comb(a, rev(free_vars c))}
      val defn = list_mk_conj (map clause (zip arg_list clist))
   in
    new_recursive_definition
        {name=Tyop^"_case_def", fixity=Prefix, rec_axiom=ax,
         def = defn}
   end;


(*---------------------------------------------------------------------------*
 * Term size, as a function of types. Types not registered in gamma are      *
 * translated into the constant function that returns 0. The function theta  *
 * maps a type variable (say 'a) to a term variable "f" of type 'a -> num.   *
 * The function gamma maps type operator "ty" to term "ty_size".             *
 *                                                                           *
 * When actually building a measure function, the behaviour of theta is      *
 * changed to be such that it maps type variables to the constant function   *
 * that returns 0.                                                           *
 *---------------------------------------------------------------------------*)

val numty = mk_type{Tyop="num", Args=[]};
val Zero = mk_const{Name="0", Ty=numty};
val One  = mk_const{Name="1", Ty=numty};
val Plus = mk_const{Name="+", Ty=Type`:num -> num -> num`};

(* mk_plus x y: the term x + y. Both x and y are expected to have type num,
   matching the type of Plus. *)
fun mk_plus x y = list_mk_comb(Plus,[x,y]);

(* K0 ty: the term \v:ty. 0, the constant-zero size function on ty. *)
fun K0 ty = mk_abs{Bvar=mk_var{Name="v",Ty=ty}, Body=Zero};

(* drop args ty: strip one domain off the function type ty for each element
   of args, then return the domain of what remains. tysize uses it to find
   the type measured by a size function once that function has received one
   size-function argument per type argument. *)
fun drop [] ty = fst(dom_rng ty)
  | drop (_::t) ty = drop t (snd(dom_rng ty));
       

(*---------------------------------------------------------------------------*
 * The following is the prototype code for tysize and may give some insight  *
 * into what it does.                                                        *
 *                                                                           *
 *   fun tysize ty =                                                         *
 *    if is_vartype ty then Term`\x:^(ty_antiq ty).0`                        *
 *    else case (dest_type ty)                                               *
 *          of {Tyop="bool",...}         => Term`\x:bool. 0`                 *
 *           | {Tyop="fun",...}          => Term`\x:^(ty_antiq ty). 0`       *
 *           | {Tyop="num",...}          => Term`\x:num. x`                  *
 *           | {Tyop="list",Args=[ety]}  => Term`list_size ^(tysize ety)`    *
 *           | {Tyop="prod",Args=[ty1,ty2]} => Term                          *
 *               `\(x:^(ty_antiq ty1),                                       *
 *                 (y:^(ty_antiq ty2))). ^(tysize ty1) x + ^(tysize ty2) y`  *
 *           | _ => raise ERR "tysize" "unknown type";                       *
 *                                                                           *
 *---------------------------------------------------------------------------*)

(* tysize (theta,gamma) ty: a term of type ty -> num measuring values of ty.
   1. If theta ty is SOME t, the result is t.
   2. Otherwise, if gamma gives a size function f for the type operator of
      ty, f is instantiated to match ty and applied to the sizes of the type
      arguments, computed recursively.
   3. Otherwise the result is K0 ty. *)
fun tysize (theta,gamma) ty  =
   case (theta ty)
     of SOME fvar => fvar
      | NONE =>
         let val {Tyop,Args} = dest_type ty
         in case (gamma Tyop)
             of SOME f => 
                  let val vty = drop Args (type_of f)
                      val sigma = Type.match_type vty ty
                  in list_mk_comb(inst sigma f, 
                                  map (tysize (theta,gamma)) Args)
                  end
              | NONE => K0 ty
         end;

(* mk_ty_fun vty (V,away): create a fresh variable of type vty -> num and
   cons it onto both V and away. The variable is named "f", or "f2", "f3",
   ... if that name is already used in away. The local num_variant differs
   from the top-level one: its numbering starts at 2. *)
local fun num_variant vlist v =
        let val counter = ref 1
            val {Name,Ty} = dest_var v
            fun pass str list = 
              if (mem str list) 
              then ( counter := !counter + 1;
                     pass (Name^Lib.int_to_string(!counter)) list)
              else str
        in 
          mk_var{Name=pass Name (map (#Name o dest_var) vlist),  Ty=Ty}
        end
in
fun mk_ty_fun vty (V,away) = 
    let val fty = vty --> numty
        val v = num_variant away (mk_var{Name="f", Ty=fty})
    in 
       (v::V, v::away)
    end
end;


(* define_size ax tysize_env: define <Tyop>_size for the datatype whose
   recursion axiom is ax. The definition is stored as <Tyop>_size_def.

   The constant first takes one size function per type argument of the
   datatype, then a value of the datatype.
   - A nullary constructor has size 0.
   - Any other constructor has size 1 plus the sizes of its arguments.

   Argument sizes are computed by tysize, where:
   - each type argument of the datatype maps to its size-function variable;
   - the datatype itself maps to the constant being defined;
   - other type operators are looked up in tysize_env.

   Every HOL_ERR is reported as ERR "define_size" "failed". *)
fun define_size ax tysize_env =
 let val exu = snd(strip_forall(concl ax))
     val {Rator,Rand} = dest_comb exu
     (* Explicit check: a refutable pattern here would raise Bind, which the
        HOL_ERR handler below would not catch. *)
     val _ = if #Name(dest_const Rator) = "?!" then ()
             else raise ERR "define_size" "axiom is not a ?! statement"
     val {Bvar,Body} = dest_abs Rand
     val (dty,rty) = Type.dom_rng (type_of Bvar)
     val {Tyop,Args} = dest_type dty
     val clist = map (rand o lhs o #2 o strip_forall) (strip_conj Body)
     val arglist = rev(fst(rev_itlist mk_ty_fun Args ([],free_varsl clist)))
     val v = mk_var{Name = Tyop^"_size",
                    Ty = itlist (curry op-->) (map type_of arglist) 
                                (dty --> numty)}
     val preamble = list_mk_comb(v,arglist)
     val f0 = zip Args arglist
     fun theta tyv = case (assoc1 tyv f0) of SOME(_,x) => SOME x | _ => NONE
     fun gamma str = 
          if (str = Tyop) then SOME v 
          else tysize_env str
     fun mk_app x = mk_comb{Rator=tysize(theta,gamma) (type_of x), Rand=x}
     fun capp2rhs capp = 
          case snd(strip_comb capp)
           of [] => Zero
            | L  => end_itlist mk_plus (One::map mk_app L)
     fun clause c = mk_eq{lhs=mk_comb{Rator=preamble,Rand=c},rhs=capp2rhs c}
     val defn = list_mk_conj (map clause clist) 
 in
    new_recursive_definition
        {name=Tyop^"_size_def", fixity=Prefix, rec_axiom=ax,
         def = defn}
 end
 handle HOL_ERR _ => raise ERR "define_size" "failed";


(* tysize_env db: maps a type-operator name to SOME size-function term when
   the facts recorded in db for that name carry one (via Facts.size_of),
   and to NONE otherwise. *)
fun tysize_env db = join Facts.size_of (Facts.get db);


(* type_size db ty: a size-function term for ty, using the size functions
   recorded in db. Type variables are measured by the constant-zero
   function. *)
fun type_size db ty = 
  let fun theta ty = 
        if (is_vartype ty) then SOME (K0 ty) else NONE 
  in 
    tysize (theta,tysize_env db) ty
  end;


(* primHol_datatype db q: define the datatype described by the quotation q
   and return its Facts.tyinfo. db itself is not modified; it is only used
   to find size functions for previously recorded types.
   Steps:
   - Parse q and give every constructor Prefix fixity.
   - Define the type with Define_type.dtype, passing the type name as
     save_name.
   - Define the case constant (define_case).
   - Try to prove one-one and distinctness theorems; each becomes NONE if
     its proof raises HOL_ERR.
   - Define the size function (define_size) and attach it to the tyinfo.
   A HOL_ERR raised along the way is passed to Raise. NOTE(review): Raise
   presumably reports and re-raises the exception -- confirm. *)
fun primHol_datatype db q = 
  let val {ty_name, clauses} = Parse.type_spec_parser q
      fun prefix{constructor,args} = 
          {constructor=constructor, args=args, fixity=Prefix}
      val clauses' = map prefix clauses
      open Define_type
      val ax = dtype{clauses=clauses',save_name=ty_name,ty_name=ty_name}
      val tyinfo = Facts.mk_facts
           {datatype_ax = ax,
            case_def = define_case ax,
            one_one  = SOME(prove_constructors_one_one ax) 
                       handle HOL_ERR _ => NONE,
            distinct = SOME(prove_constructors_distinct ax) 
                       handle HOL_ERR _ => NONE}
      val size_def = define_size (Facts.axiom_of tyinfo) (tysize_env db)
  in
      Facts.put_size (defn_const size_def,size_def) tyinfo
  end 
  handle e as HOL_ERR _ => Raise e;



(*---------------------------------------------------------------------------*
 * Create the database.                                                      *
 *---------------------------------------------------------------------------*)

(* dBase is private. theFactBase returns its current contents, and add
   replaces it with the result of Facts.add applied to the current contents
   and the given facts. *)
local val dBase = ref Facts.empty
in
fun theFactBase() = !dBase;

fun add facts = (dBase := Facts.add (theFactBase()) facts);

end;

(* get s: look s up (via Facts.get) in the current fact base. *)
fun get s = Facts.get (theFactBase()) s;

(* Hol_datatype q: define the datatype described by q, using the current
   fact base for size functions. The resulting facts are added to the fact
   base; the function returns unit. *)
fun Hol_datatype q = add (primHol_datatype (theFactBase()) q);

(*---------------------------------------------------------------------------*
 * Install datatype facts for bool, pairs, numbers, and lists                *
 * into theFactBase.                                                         *
 *---------------------------------------------------------------------------*)

open boolTheory;

(* Recursion axiom for bool: for any e0 and e1 there is a unique fn with
   fn T = e0 and fn F = e1. *)
val boolAxiom = prove(Term`!(e0:'a) e1. ?!fn. (fn T = e0) /\ (fn F = e1)`,
SUBST_TAC[INST_TYPE[Type`:'a` |-> Type`:bool -> 'a`] EXISTS_UNIQUE_DEF]
  THEN BETA_TAC THEN BETA_TAC THEN REPEAT GEN_TAC THEN CONJ_TAC THENL
  [EXISTS_TAC(Term`\x. (x=T) => (e0:'a) | e1`) THEN BETA_TAC
     THEN SUBST_TAC[EQT_INTRO(REFL(Term`T`)),
                    EQF_INTRO(CONJUNCT2(BOOL_EQ_DISTINCT))]
     THEN SUBST_TAC(CONJUNCTS (SPECL [Term`e0:'a`, Term`e1:'a`] COND_CLAUSES))
     THEN CONJ_TAC THEN REFL_TAC,
   CONV_TAC (DEPTH_CONV FUN_EQ_CONV) THEN REPEAT STRIP_TAC 
     THEN BOOL_CASES_TAC(Term`b:bool`)
     THEN REPEAT (POP_ASSUM SUBST1_TAC) THEN REFL_TAC]);

(* Case-constant clauses for bool. bool_case_DEF is specialised at its first
   two bound variables, then at T and at F, and each instance is simplified
   with COND_CLAUSES. The two results are conjoined, with the first two
   variables generalised again. This assumes bool_case_DEF has exactly three
   outer quantifiers and a conditional body -- the val pattern below fails
   otherwise. *)
val bool_case_def = 
  let val [x,y,_] = #1(strip_forall(concl boolTheory.bool_case_DEF))
      val thm = SPEC y (SPEC x boolTheory.bool_case_DEF)
      val thmT = SPEC (Parse.Term`T`) thm
      val thmF = SPEC (Parse.Term`F`) thm
      val CC = SPECL [Term`x:'a`, Term`y:'a`] COND_CLAUSES
      val thmT' = SUBS[CONJUNCT1 CC] thmT
      val thmF' = SUBS[CONJUNCT2 CC] thmF
      fun gen th = GEN x (GEN y th)
  in
   CONJ (gen thmT') (gen thmF')
  end;

(* bool_case with identical branches ignores its boolean argument. *)
val bool_case_rw = prove(Term`!x y. bool_case x x y = (x:'a)`,
 REPEAT GEN_TAC 
   THEN BOOL_CASES_TAC (Term`y:bool`) 
   THEN Rewrite.ASM_REWRITE_TAC[bool_case_def]);

val bool_info = 
   Facts.mk_facts
        {datatype_ax = boolAxiom,
         case_def = bool_case_def,
         distinct = SOME (CONJUNCT1 BOOL_EQ_DISTINCT),
         one_one = (NONE:thm option)};

(* The size of a boolean is always 0: bool_case 0 0, justified by
   bool_case_rw. *)
val bool_size_info =
       (Term`bool_case 0 0`, bool_case_rw)

val bool_info' = Facts.put_size bool_size_info bool_info


(*---------------------------------------------------------------------------*
 * Products                                                                  *
 *---------------------------------------------------------------------------*)
val prod_info = 
         Facts.mk_facts
              {datatype_ax = pairTheory.pair_Axiom,
               case_def = pairTheory.UNCURRY_DEF,
               distinct = NONE, 
               one_one = SOME pairTheory.CLOSED_PAIR_EQ}

(* Size of a pair: the sum of the component sizes. *)
val prod_size_info =
       (Parse.Term`\f g. UNCURRY(\x y. f (x:'a) + g (y:'b))`,
        pairTheory.UNCURRY_DEF)

val prod_info' = Facts.put_size prod_size_info prod_info


(*---------------------------------------------------------------------------*
 * Numbers                                                                   *
 *---------------------------------------------------------------------------*)
val num_info = 
         Facts.mk_facts
              {datatype_ax = prim_recTheory.num_Axiom,
               case_def = arithmeticTheory.num_case_def,
               distinct = SOME numTheory.NOT_SUC, 
               one_one = SOME prim_recTheory.INV_SUC_EQ}

(* The size of a number is the number itself (the identity function I). *)
val num_size_info = 
        (Parse.Term`I:num->num`, combinTheory.I_THM)

(*---------------------------------------------------------------------------*
 * Note. Adding in num_CASES, because the generated nchotomy theorem has     *
 * primes in the names, which is ugly in use.                                *
 *---------------------------------------------------------------------------*)
val num_info' = Facts.put_nchotomy
                  arithmeticTheory.num_CASES
                   (Facts.put_size num_size_info num_info);


(*---------------------------------------------------------------------------*
 * Lists                                                                     *
 *---------------------------------------------------------------------------*)
val list_info = 
        Facts.mk_facts
             {datatype_ax = listTheory.list_Axiom,
              case_def = listTheory.list_case_def,
              distinct = SOME listTheory.NOT_NIL_CONS, 
              one_one = SOME listTheory.CONS_11}

val list_size_info = 
       (Parse.Term`list_size:('a->num) -> 'a list -> num`, 
        listTheory.list_size)

val list_info' = Facts.put_size list_size_info list_info

(*---------------------------------------------------------------------------*
 * Options                                                                   *
 *---------------------------------------------------------------------------*)

(* The recursion axiom is rebuilt from option_Axiom:
   - its body has the order of its conjuncts swapped (CONJ_SYM);
   - it is regeneralised over its outer variables in reverse order.
   NOTE(review): presumably this puts the NONE clause first, in the layout
   Facts expects -- confirm. *)
val option_info = 
     Facts.mk_facts
         {datatype_ax = 
            let val fe = fst(strip_forall(concl optionTheory.option_Axiom))
            in GENL (rev fe)
                    (Rewrite.PURE_ONCE_REWRITE_RULE [boolTheory.CONJ_SYM]
                         (SPEC_ALL optionTheory.option_Axiom))
            end,
          case_def = optionTheory.option_CASE_DEF,
          distinct = SOME optionTheory.NOT_NONE_SOME, 
          one_one = SOME optionTheory.SOME_11}

(* Size of an option: NONE has size 0, and SOME x has size 1 + f x. *)
val option_size_info = 
       (Parse.Term`\f:'a->num. option_CASE 0 (\x. 1 + f x)`, 
        optionTheory.option_CASE_DEF);

val option_info' = Facts.put_size option_size_info option_info

(* Record the built-in types in the fact base when this structure loads. *)
val _ = map add [bool_info', prod_info', num_info', list_info', option_info']

end;
