(*---------------------------------------------------------------------------*
 * Building records of facts about datatypes. This is all purely functional, *
 * i.e., there are no side-effects happening anywhere here.                  *
 *---------------------------------------------------------------------------*)

structure Facts :> Facts =
struct

open HolKernel Drule Conv Prim_rec;
infix ## |->;

 type term = Term.term
 type thm = Thm.thm


(*---------------------------------------------------------------------------*
 * Build a HOL_ERR exception tagged with this structure's name.              *
 *---------------------------------------------------------------------------*)
fun ERR f s = 
  HOL_ERR{origin_structure = "Facts", 
          origin_function=f,message = s};


(*---------------------------------------------------------------------------*
 * A tyinfo pairs a datatype name with a record of terms and theorems about  *
 * the datatype. The "size", "distinct" and "one_one" fields are optional.   *
 *---------------------------------------------------------------------------*)
datatype tyinfo =
  FACTS of  string *
            {axiom:thm, 
             case_const:term, 
             case_def:thm, 
             case_cong:thm,
             constructors:term list,
             induction:thm,
             nchotomy:thm, 
             size : (term * thm) option,
             distinct:thm option, 
             one_one:thm option, 
             simpls:thm list};

(*---------------------------------------------------------------------------
 * Projections: each returns one field of a tyinfo.
 *---------------------------------------------------------------------------*)
fun constructors_of (FACTS(_, {constructors,...})) = constructors
fun case_const_of (FACTS(_,{case_const,...})) = case_const
fun case_cong_of (FACTS(_,{case_cong,...})) = case_cong
fun case_def_of (FACTS(_,{case_def,...})) = case_def
fun induction_of (FACTS(_,{induction,...})) = induction
fun nchotomy_of (FACTS(_,{nchotomy,...})) = nchotomy
fun distinct_of (FACTS(_,{distinct,...})) = distinct
fun one_one_of (FACTS(_,{one_one,...})) = one_one
fun simpls_of (FACTS(_,{simpls,...})) = simpls
fun axiom_of (FACTS(_,{axiom,...})) = axiom
fun size_of (FACTS(_,{size,...})) = size
fun ty_name_of (FACTS(s,_)) = s

(*---------------------------------------------------------------------------*
 * Alterations. Each put_ function returns a new tyinfo that is identical    *
 * to its argument except for the one field named; the argument itself is    *
 * not modified.                                                             *
 *---------------------------------------------------------------------------*)
fun put_nchotomy th (FACTS(s,
 {axiom, case_const,case_cong,case_def,constructors,
  induction, nchotomy, distinct, one_one, simpls, size}))
  =
  FACTS(s, {axiom=axiom, case_const=case_const,
            case_cong=case_cong,case_def=case_def, constructors=constructors,
            induction=induction, nchotomy=th, distinct=distinct, 
            one_one=one_one, simpls=simpls, size=size})

fun put_induction th (FACTS(s,
 {axiom, case_const,case_cong,case_def,constructors,
  induction, nchotomy, distinct, one_one, simpls, size}))
  =
  FACTS(s, {axiom=axiom, case_const=case_const,
            case_cong=case_cong,case_def=case_def, constructors=constructors,
            induction=th, nchotomy=nchotomy, distinct=distinct, 
            one_one=one_one, simpls=simpls, size=size})

fun put_simpls thl (FACTS(s,
 {axiom, case_const,case_cong,case_def,constructors,
  induction, nchotomy, distinct, one_one, simpls, size}))
  =
  FACTS(s, {axiom=axiom, case_const=case_const,
            case_cong=case_cong,case_def=case_def,constructors=constructors,
            induction=induction, nchotomy=nchotomy, distinct=distinct, 
            one_one=one_one, simpls=thl, size=size});


(* Unlike the other put_ functions, this one wraps its argument in SOME. *)
fun put_size (size_tm,size_rw) (FACTS(s,
 {axiom, case_const,case_cong,case_def,constructors,
  induction, nchotomy, distinct, one_one, simpls, size}))
  =
  FACTS(s, {axiom=axiom, case_const=case_const,
            case_cong=case_cong,case_def=case_def,constructors=constructors,
            induction=induction, nchotomy=nchotomy, distinct=distinct, 
            one_one=one_one, simpls=simpls, size=SOME(size_tm,size_rw)});



(*---------------------------------------------------------------------------
 * Defining and proving case congruence:
 *
 *    |- (M = M') /\
 *       (!x1,...,xk. (M' = C1 x1..xk) ==> (f1 x1..xk = f1' x1..xk)) 
 *        /\ ... /\
 *       (!x1,...,xj. (M' = Cn x1..xj) ==> (fn x1..xj = fn' x1..xj)) 
 *       ==>
 *      (ty_case f1..fn M = ty_case f1'..fn' M')
 *
 * case_cong_term builds the (unproved) term above from the clauses of the
 * case definition. The fresh variables M, M' and f1'..fn' are all chosen
 * to be distinct from each other and from every variable in the clauses.
 *---------------------------------------------------------------------------*)
fun case_cong_term case_def =
 let val clauses = (strip_conj o concl) case_def
     val clause1 = Lib.trye hd clauses
     val left = (#lhs o dest_eq o #2 o strip_forall) clause1
     val ty = type_of (rand left)
     val allvars = all_varsl clauses
     val M = variant allvars (mk_var{Name = "M", Ty = ty})
     val M' = variant (M::allvars) (mk_var{Name = "M", Ty = ty})
     (* Walk the clauses in order. Every primed function variable that is
        generated is added to the avoid list, so that two different
        functions (e.g. f and f') can never be given the same primed name. *)
     fun mk_clauses _ [] = []
       | mk_clauses avoid (clause::rst) =
          let val {lhs,rhs} = (dest_eq o #2 o strip_forall) clause
              val func = (#1 o strip_comb) rhs
              val (constr,xbar) = strip_comb(rand lhs)
              val {Name,Ty} = dest_var func
              val func' = variant avoid (mk_var{Name=Name^"'", Ty=Ty})
              val clause' =
                list_mk_forall
                 (xbar,
                  mk_imp{ant = mk_eq{lhs=M',rhs=list_mk_comb(constr,xbar)},
                         conseq = mk_eq{lhs=list_mk_comb(func,xbar),
                                        rhs=list_mk_comb(func',xbar)}})
          in (func', clause') :: mk_clauses (func'::avoid) rst
          end
     val (funcs',clauses') = unzip (mk_clauses (M::M'::allvars) clauses)
     val lhsM = mk_comb{Rator=rator left, Rand=M}
     val c = #1(strip_comb left)
 in
 mk_imp{ant = list_mk_conj(mk_eq{lhs=M, rhs=M'}::clauses'),
        conseq = mk_eq{lhs=lhsM, rhs=list_mk_comb(c,(funcs'@[M']))}}
 end;
  
(*---------------------------------------------------------------------------*
 *                                                                           *
 *        A, v = M[x1,...,xn] |- N                                           *
 *  ------------------------------------------                               *
 *     A, ?x1...xn. v = M[x1,...,xn] |- N                                    *
 *                                                                           *
 * thm must have exactly one equational hypothesis; any other number raises *
 * a HOL_ERR. Each variable in vlist is replaced by its image under theta,   *
 * if it has one, before it is existentially quantified.                     *
 *---------------------------------------------------------------------------*)
fun EQ_EXISTS_LINTRO (thm,(vlist,theta)) = 
  let val veq = (case filter (can dest_eq) (hyp thm)
                  of [tm] => tm
                   | _ => raise ERR "EQ_EXISTS_LINTRO"
                                "expected exactly one equational hypothesis")
      fun CHOOSER v (tm,thm) = 
        let val w = (case (subst_assoc (fn w => v=w) theta)
                      of SOME w => w
                       | NONE => v)
            val ex_tm = mk_exists{Bvar=w, Body=tm}
        in (ex_tm, CHOOSE(w, ASSUME ex_tm) thm)
        end
  in snd(itlist CHOOSER vlist (veq,thm))
  end;


(*---------------------------------------------------------------------------*
 * Checks the shape of case_def. The function arguments f1..fn of the case   *
 * constant (taken from the first clause's left-hand side) must match, in    *
 * order, the head operators of the clauses' right-hand sides.               *
 *---------------------------------------------------------------------------*)
fun OKform case_def = 
  let val clauses = (strip_conj o concl) case_def
      val left = (rator o #lhs o dest_eq o #2 o strip_forall) 
                 (Lib.trye hd clauses)
      val opvars = #2 (strip_comb left)
      fun rhs_head c = fst(strip_comb(rhs(snd(strip_forall c))))
      val rhs_heads = map rhs_head clauses
      fun check [] = true
        | check ((x,y)::rst) = (x=y) andalso check rst
  in 
     check (zip opvars rhs_heads)
  end    


(*---------------------------------------------------------------------------*
 * Proves the congruence term built by case_cong_term. It does a case split  *
 * on M' using the nchotomy theorem, then rewrites each case with case_def.  *
 * Any HOL_ERR raised during the construction is reported as                *
 * ERR "case_cong_thm" "construction failed".                               *
 *---------------------------------------------------------------------------*)
fun case_cong_thm nchotomy case_def =
 let val _ = assert OKform case_def
     val gl = case_cong_term case_def
     val {ant,conseq} = dest_imp gl
     val imps = CONJUNCTS (ASSUME ant)
     val M_eq_M' = hd imps
     val {lhs=M, rhs=M'} = dest_eq (concl M_eq_M')
     (* Antecedent of a (quantified) implication. Non-implications, such as
        the leading M = M' conjunct, are returned unchanged. *)
     fun get_asm tm =
       (#ant o dest_imp o #2 o strip_forall) tm handle HOL_ERR _ => tm
     val case_assms = map (ASSUME o get_asm o concl) imps
     val {lhs=lconseq, rhs=rconseq} = dest_eq conseq
     val lconseq_thm = SUBST_CONV[M |-> M_eq_M'] lconseq lconseq
     val lconseqM' = rhs(concl lconseq_thm)
     val nchotomy' = ISPEC M' nchotomy
     val disjrl = map ((I##rhs) o strip_exists) (strip_disj (concl nchotomy'))
     (* For one constructor case: prove that the case expression applied to
        M' equals the primed case expression, under the assumption that M'
        is an instance of that constructor. Also returns the matching
        substitution that EQ_EXISTS_LINTRO needs. *)
     fun zot icase_thm (iimp,(vlist,disjrhs)) =
       let val lth = Rewrite.REWRITE_CONV[icase_thm, case_def] lconseqM'
           val rth = Rewrite.REWRITE_CONV[icase_thm, case_def] rconseq
           val theta = Term.match_term disjrhs
                     ((rhs o #ant o dest_imp o #2 o strip_forall o concl) iimp)
           val th = MATCH_MP iimp icase_thm
           val th1 = TRANS lth th
       in (TRANS th1 (SYM rth), (vlist, #1 theta))
       end
     val thm_substs = map2 zot (tl case_assms) (zip (tl imps) disjrl)
     val aag = map (TRANS lconseq_thm o EQ_EXISTS_LINTRO) thm_substs
 in
   GEN_ALL(DISCH_ALL(DISJ_CASESL nchotomy' aag))
 end 
 handle HOL_ERR _ => raise ERR "case_cong_thm" "construction failed";


(*---------------------------------------------------------------------------*
 * Returns the datatype name and the constructors. The code is a copy of     *
 * the beginning of "define_case". The axiom's body (under its universal     *
 * quantifiers) must be a unique-existence ("?!") term; otherwise a HOL_ERR  *
 * is raised.                                                                *
 *---------------------------------------------------------------------------*)
fun ax_info ax = 
  let val exu = snd(strip_forall(concl ax))
      val {Rator,Rand} = dest_comb exu
      val _ = if #Name (dest_const Rator) = "?!" then ()
              else raise ERR "ax_info" "expected a ?! (unique existence) axiom"
      val {Bvar,Body} = dest_abs Rand
      val (dty,_) = Type.dom_rng (type_of Bvar)
      val {Tyop,...} = dest_type dty
      val clist = map (rand o lhs o #2 o strip_forall) (strip_conj Body)
  in
   (Tyop,  map (fst o strip_comb) clist)
  end;

(* The constant being defined: the head operator of the left-hand side of
   the first conjunct of a definition theorem. *)
val defn_const = 
  #1 o strip_comb o lhs o #2 o strip_forall o hd o strip_conj o concl;

(*---------------------------------------------------------------------------*
 * The one-one (injectivity) theorem and the distinctness theorem currently  *
 * have to be supplied, because the routines that prove them use numbers,    *
 * which some formalizations may not need. mk_facts does not fill in the     *
 * size field (it is NONE), because computing it requires access to the      *
 * current fact database.                                                    *
 *                                                                           *
 * simpls is made up of case_def, the distinctness conjuncts in both         *
 * orientations, and the one-one theorem, when they are present.             *
 *---------------------------------------------------------------------------*)
fun mk_facts {datatype_ax, case_def, one_one, distinct} =
  let val induct_thm = prove_induction_thm datatype_ax
      val nchotomy = prove_cases_thm induct_thm
      val (ty_name,constructors) = ax_info datatype_ax
      val inj = case one_one of NONE => [] | SOME x => [x]
      val D = case distinct of NONE => [] | SOME x => CONJUNCTS x
  in
   FACTS(ty_name,
     {constructors = constructors,
      case_const = defn_const case_def,
      case_def   = case_def,
      case_cong  = case_cong_thm nchotomy case_def,
      induction  = induct_thm,
      nchotomy   = nchotomy, 
      one_one    = one_one,
      distinct   = distinct,
      simpls     = case_def :: (D@map GSYM D@inj),
      size       = NONE,
      axiom      = datatype_ax})
  end;

(*---------------------------------------------------------------------------*
 * Databases of facts: sets of tyinfo ordered by datatype name. "get" scans  *
 * the set for an entry with the given name; "add" inserts an entry.         *
 *---------------------------------------------------------------------------*)
type factBase = tyinfo Binaryset.set

val empty = 
   Binaryset.empty (fn (f1,f2) => 
       String.compare (ty_name_of f1, ty_name_of f2));

fun get db s = Binaryset.find (fn f => (s = ty_name_of f)) db;
fun add db f = Binaryset.add(db,f);
val elts = Binaryset.listItems;

end;
