structure EI : EI =

struct

(* Classifier kinds are plain strings; std_classifier below treats "wave"
 * specially and every other kind uniformly. *)
type classifier_kind = string;

(* A "visible decider": the boolean says whether classify0 may try to refute
 * a goal by deciding its negation; the function attempts to prove a goal
 * (hypotheses, conclusion) and returns the theorem together with a printer
 * that renders the ML text of the proof step. *)
type vdecider =
   bool * (term list * term -> thm * (PrettyPrint.ppstream -> unit));

(* Build (but do not raise) a HOL_ERR exception attributed to structure EI.
 *   fname -- name of the reporting function
 *   msg   -- human-readable description of the failure *)
fun error (fname, msg) =
   HOL_ERR {message          = msg,
            origin_function  = fname,
            origin_structure = "EI"};

(* A fresh integer stream for naming facts: it starts at 0, steps by 1, and
 * renders each state as "th" followed by the integer. *)
fun namesrc () =
   let fun step i = i + 1
       fun label i = "th" ^ Lib.int_to_string i
   in  Lib.mk_istream step 0 label
   end;


(*---------------------------------------------------------------------------*
 * classify0 () builds a classifier that owns a private name stream.         *
 *                                                                           *
 * The classifier takes a function from classifier kinds to vdeciders, a    *
 * kind, and a goal (hyps, conclusion). It runs the selected decider on the *
 * goal. If that raises anything, and the decider permits refutation, it    *
 * tries to decide the negated conclusion under the same hypotheses. The    *
 * outcome (Proved / Refuted / CantProve) is recorded in the Fact database  *
 * under a freshly generated name, and (name, fact) is returned.            *
 *                                                                           *
 * The partial application matters: the name stream is created when () is  *
 * supplied, so classify0 () should be evaluated only once per extended     *
 * interaction. (Open question: this might belong in vistactic instead.)    *
 *---------------------------------------------------------------------------*)
fun classify0 () =
   let val counter = namesrc ()
       (* Read the current name, then advance the stream. *)
       fun fresh_name () =
          let val nm = Lib.state counter
              val _  = Lib.next counter
          in  nm
          end
       fun classify vdeciderf kind (gl as (hyps, goal_tm)) =
          let val (may_refute, decide) = vdeciderf kind
              val fact =
                 Fact.Proved (decide gl)
                 handle _ =>
                    (* mk_neg is only built when refutation is allowed *)
                    if may_refute andalso can decide (hyps, mk_neg goal_tm)
                    then Fact.Refuted gl
                    else Fact.CantProve gl
              val nm = fresh_name ()
              val _  = Fact.add (nm, fact)
          in  (nm, fact)
          end
   in  classify
   end;

(* The default decider. Refutation is allowed (flag = true). Goals are proved
 * with CallClam.VisibleTactic.DECIDE_GOAL, and the accompanying printer
 * emits the ML text "VisibleTactic.QDECIDE_GOAL <quoted goal>". *)
val std_decider =
   let fun pp_step gl ppstrm =
          (PrettyPrint.begin_block ppstrm PrettyPrint.CONSISTENT 2;
           PrettyPrint.add_string ppstrm "VisibleTactic.QDECIDE_GOAL";
           PrettyPrint.add_break ppstrm (1,0);
           CallClam.VisibleTactic.pp_quoted_goal ppstrm gl;
           PrettyPrint.end_block ppstrm)
       fun decide gl = (CallClam.VisibleTactic.DECIDE_GOAL gl, pp_step gl)
   in  (true, decide)
   end;

local

structure R = retrieveLib;

(* Swap the sides of an equation: a = b becomes b = a. Fails (through
 * rhs/lhs) when tm is not an equation. *)
fun sym tm =
   let val new_lhs = rhs tm
       val new_rhs = lhs tm
   in  mk_eq {lhs = new_lhs, rhs = new_rhs}
   end;

(* Check that a retrieved theorem matches the target term tm (outer universal
 * quantifiers stripped on both sides). The final flag is false for a direct
 * match and true when only the symmetric form of tm matches. If neither
 * matches, the HOL_ERR from the second match_term propagates. *)
fun check_theorem tm (kind,thy,name,th) =
   let val (_, target)  = strip_forall tm
       val (_, pattern) = strip_forall (concl th)
       val flipped =
          let val _ = match_term pattern target in false end
          handle HOL_ERR _ =>
             let val _ = match_term pattern (sym target) in true end
   in  (kind, thy, name, th, flipped)
   end;

(* Search the ancestry of the current theory for a stored fact whose
 * conclusion matches the goal's conclusion tm (in either orientation when
 * tm is an equation). Returns the first conforming candidate as produced by
 * check_theorem. Raises error ("find_theorem","no match") when the search
 * fails with HOL_ERR. *)
fun find_theorem (_,tm) =
   let val (_, body) = strip_forall tm
       fun body_pat t = R.has_body (Parse.term_parser`conc:bool`, t)
       val pat =
          if is_eq body
          then R.Orelse (body_pat body, body_pat (sym body))
          else body_pat body
       val src = R.Paths [R.Ancestors ([current_theory ()], [])]
       (* TODO: discard non-conforming theorems from each new step before
        * carrying on with the search. *)
       fun first_match step =
          hd (mapfilter (check_theorem tm) (R.show_step step))
       fun search step =
          first_match step
          handle Empty => search (R.search_until_find step)
   in  search (R.find_theorems pat src)
       handle HOL_ERR _ => raise error ("find_theorem","no match")
   end;

(* Prompt with message and read one line from standard input. Returns the
 * line's first character when it is in valid_chars and the line has at most
 * two characters (a letter plus newline); otherwise it prompts again.
 *
 * NOTE(review): this assumes the older Basis TextIO.inputLine, which returns
 * "" at end of stream. The previous version called String.sub (s,0) before
 * any size check, so end of input raised Subscript. An empty line is now
 * reported as a HOL_ERR. A plain retry is not used because at end of input
 * it would loop forever. *)
fun input message valid_chars =
   let val _ = (Lib.say message; TextIO.flushOut TextIO.stdOut)
       val s = TextIO.inputLine TextIO.stdIn
   in  if s = ""
       then raise error ("input", "end of input reached")
       else
          let val c = String.sub (s, 0)
          in  if (String.size s <= 2) andalso (Lib.mem c valid_chars)
              then c
              else input message valid_chars
          end
   end;

(* Fallback for a hypothesis-free conjecture that the decision procedures
 * could not prove: look for a matching axiom/definition/theorem in the
 * current theory's ancestry. On success, report it and return the theorem
 * (GSYM'd if only the symmetric form matched) together with a printer for
 * the ML expression that fetches it. If no match is found (any HOL_ERR along
 * the way), raise Fail "No theorem found". The interactive search/reject
 * prompt is disabled, so the search branch is always taken. *)
fun user_selection ([],tm) =
   let (* Prints `<kind_fn> "<thy>" "<name>"`. *)
       fun pp_fetch kind_fn thy name ppstrm =
          (PrettyPrint.begin_block ppstrm PrettyPrint.INCONSISTENT 2;
           PrettyPrint.add_string ppstrm kind_fn;
           PrettyPrint.add_break ppstrm (1,0);
           PrettyPrint.add_string ppstrm (Lib.quote thy);
           PrettyPrint.add_break ppstrm (1,0);
           PrettyPrint.add_string ppstrm (Lib.quote name);
           PrettyPrint.end_block ppstrm)
       (* Prints `GSYM (<inner>)`. *)
       fun pp_gsym inner ppstrm =
          (PrettyPrint.begin_block ppstrm PrettyPrint.CONSISTENT 2;
           PrettyPrint.add_string ppstrm "GSYM";
           PrettyPrint.add_break ppstrm (1,0);
           PrettyPrint.add_string ppstrm "(";
           inner ppstrm;
           PrettyPrint.add_string ppstrm ")";
           PrettyPrint.end_block ppstrm)
       (* Kind -> (word used in the message, ML fetch function name). *)
       val kind_table =
          [(R.Axiom,      ("axiom","axiom")),
           (R.Definition, ("definition","definition")),
           (R.Theorem,    ("theorem","theorem"))]
       fun from_library () =
          let val (kind,thy,name,th,flipped) = find_theorem ([],tm)
              val (kind_str,kind_fn) = assoc kind kind_table
              val pp = pp_fetch kind_fn thy name
          in  Lib.say ("Found matching " ^ kind_str ^ " `" ^ name ^ "'" ^
                       " in theory segment `" ^ thy ^ "'\n\n");
              if flipped then (GSYM th, pp_gsym pp) else (th, pp)
          end
       (* Interactive prompt, currently disabled:
       val choice = input "(S)earch for a matching theorem, or (R)eject? "
                          [#"S",#"s",#"R",#"r"]
       *)
       val choice = #"s"
   in  if choice = #"s" orelse choice = #"S"
       then from_library ()
            handle HOL_ERR _ =>
            (Lib.say "No matching theorem found -- rejecting.\n\n";
             raise Fail "No theorem found")
       else (Lib.say "\n"; raise Fail "Rejected by user")
   end
  | user_selection _ = raise Fail "Conjecture shouldn't have hypotheses";

in

(* Decider for "wave" conjectures. Refutation is not attempted (flag =
 * false). It prints the conjectured lemma and tries std_decider. If that
 * raises anything, it falls back to user_selection (library lookup). *)
val wave_decider =
   let fun show_conjecture tm =
          (Lib.say "Clam has conjectured the following lemma:\n";
           Lib.say "  "; print_term tm; Lib.say "\n")
       fun by_decision_procedures gl =
          let val thpp = #2 std_decider gl
          in  Lib.say "Proved using decision procedures.\n\n";
              thpp
          end
   in  (false,
        fn gl => (show_conjecture (#2 gl);
                  by_decision_procedures gl
                  handle _ => user_selection gl))
   end;

end;

(* The standard classifier: "wave" goals use wave_decider, and every other
 * kind uses std_decider. Each call creates a fresh classifier (and name
 * stream) via classify0 (). *)
fun std_classifier () =
   let fun select "wave" = wave_decider
         | select _      = std_decider
   in  classify0 () select
   end;

(* A classifier that always fails. Its decider raises Fail "alwaysFail" and
 * refutation is disabled, so classify0 records every goal as CantProve. *)
fun alwaysFail () =
   let fun never_decides _ =
          (raise Fail "alwaysFail", Lib.C PrettyPrint.add_string "NO_TAC")
   in  classify0 () (fn _ => (false, never_decides))
   end;


local open Fact
in
(* True exactly when the (NEW or OLD) fact carries a Proved outcome. *)
fun proved fact =
   case fact
     of NEW (_, Proved _) => true
      | OLD (_, Proved _) => true
      | _                 => false

(* Unwrap a NEW fact. OLD facts raise Fail "NEWof", which lets callers such
 * as mapfilter skip them. *)
fun NEWof (NEW x) = x
  | NEWof (OLD _) = raise Fail "NEWof";
end;


(*---------------------------------------------------------------------------*
 * Embark on an extended interaction.                                        *
 *                                                                           *
 * EI0 class f g clears the local fact database and Clam's stored facts,     *
 * then repeatedly asks Clam (CallClam.CLAM_TAC_AST0 with the classifier     *
 * class ()) for a plan for goal g, printing "Round n" each time.            *
 *   - If every fact of a round is proved, f is applied to the tactic AST,  *
 *     the facts mapped through Fact.dest_age, and g; its result is          *
 *     returned.                                                             *
 *   - Otherwise the NEW facts are sent to Clam and another round begins.   *
 *   - If unproved facts remain but none are NEW, nothing would be sent to  *
 *     Clam and the next round would presumably repeat this one forever, so *
 *     HOL_ERR (origin EI0) is raised instead.                               *
 *---------------------------------------------------------------------------*)
fun EI0 class f g =
 let val _ = (Fact.clear();  (* reset the DB of facts *)
              CallClam.delete_all_facts ())
     val mk_tacAST = CallClam.CLAM_TAC_AST0 CallClam.prove_goal (class())
     fun roll n =
       let val _ = Portable.output(Portable.std_out,
                    "Round "^Lib.int_to_string n^"\n")
           val (tacAST, facts) =  mk_tacAST g
       in
          if (Lib.all proved facts)  (* done *)
          then f (tacAST, map Fact.dest_age facts) g
          else case mapfilter NEWof facts
                 of [] => raise error ("EI0",
                             "unproved facts remain but none are new; \
                             \no further progress is possible")
                  | new_facts =>
                       (List.app (ignore o CallClam.send_fact) new_facts;
                        roll (n+1))
       end
 in
     roll 0
 end;

(* Run an extended interaction using the standard classifier.
 * EI f g behaves exactly like EI0 std_classifier f g. *)
fun EI f = EI0 std_classifier f;

end;
