(*---------------------------------------------------------------------------
 * FILE: plan2tactic.sml
 *
 * This functor provides code for mapping Clam proof plans into Client
 * tactics.
 *---------------------------------------------------------------------------*)

functor PLAN2TAC(structure VisibleTactic : VisibleTactic_sig
                 structure Clam2Client   : Clam2Client
                 structure ClamDatabase  : ClamDatabase_sig
                 sharing type VisibleTactic.term = Clam2Client.client_term) 

  : Plan2Tactic_sig =

struct

structure VisibleTactic = VisibleTactic;

open VisibleTactic;
open Clam2Client;
open PrettyPrint;
open Rsyntax;

(* Build (without raising it) a Fail exception whose message is the function
   name, a colon and a space, then the given message. *)
fun tac_err (fname, msg) = Fail (String.concat [fname, ": ", msg]);

(* Surround s with double-quote characters (used when reporting bad input). *)
fun quote s = "\"" ^ s ^ "\"";

(* Map the Clam direction strings "left" and "right" to LEFT and RIGHT.
   Any other string raises the Fail exception built by tac_err, with the
   offending string quoted in the message. *)
fun string_to_direction s =
    case s of
        "left"  => LEFT
      | "right" => RIGHT
      | _       => raise tac_err ("string_to_direction", quote s);


(*---------------------------------------------------------------------------*
 * Abstract syntax trees for compound tactics. A tactic representation       *
 * carries a prettyprinter at each leaf node, so that concrete syntax can    *
 * be generated easily.                                                      *
 *---------------------------------------------------------------------------*)
(* THEN and THENL are declared nonfix so that the constructors below can be
   declared and applied in ordinary prefix form, e.g. THEN (t1, t2). *)
nonfix THEN THENL;

datatype tacticAST =
    NO_TAC                                (* presumably mirrors HOL's NO_TAC  *)
  | ALL_TAC                               (* presumably mirrors HOL's ALL_TAC *)
  | PRIM  of VisibleTactic.vtactic        (* leaf: one visible tactic         *)
  | THEN  of tacticAST * tacticAST        (* first tactic, then the second    *)
  | THENL of tacticAST * tacticAST list   (* a tactic followed by a list of
                                             tactics; presumably one per
                                             subgoal, as in HOL -- TODO confirm *)


(*---------------------------------------------------------------------------
 * We translate methods to tacticASTs; from these we can generate both 
 * concrete syntax and code. 
 *---------------------------------------------------------------------------*)
(*---------------------------------------------------------------------------
 * Helpers private to make_tac.
 *---------------------------------------------------------------------------*)
local open ClamAST
      (* Strip the Method wrapper from a method argument; any other kind of
         argument raises the Fail exception built by tac_err. *)
      fun unarg(Method m) = m
        | unarg x = raise tac_err("mk_tac","unarg")
      (* First component of ClamDatabase.find_fact s.  Presumably the
         client-side identifier of the fact named s -- TODO confirm against
         ClamDatabase_sig. *)
      fun client_fact_id s = #1 (ClamDatabase.find_fact s)
(* Kind of rewrite rule named in a wave step: Eq for rules described with
   Equ or "equiv", Imp for rules described with "imp". *)
datatype rule_type = Eq | Imp
(* Decode the rule description carried by a wave method into the triple
   (rule type, rule name s, direction string lr).  The description is a
   Methods list whose first element names the rule.  For the Equ and "equiv"
   forms an optional middle MethodApp element is accepted; it is bound to the
   pattern variable complementary, which is never used.  Any other shape
   raises the Fail exception built by tac_err. *)
fun wave_tac (Method(Methods[MethodApp(s,[]), 
                             Equ(_,Method(MethodApp(lr,[])))]))
      = (Eq,s,lr)
  | wave_tac (Method(Methods[MethodApp(s,[]),
                             MethodApp complementary,
                             Equ(_,Method(MethodApp(lr,[])))]))
      = (Eq,s,lr)
  | wave_tac (Method(Methods[MethodApp(s,[]), 
                             MethodApp("equiv",[Method(MethodApp(lr,[]))])]))
      = (Eq,s,lr)
  | wave_tac (Method(Methods[MethodApp(s,[]),
                             MethodApp complementary,
                             MethodApp("equiv",[Method(MethodApp(lr,[]))])]))
      = (Eq,s,lr)
  | wave_tac (Method(Methods[MethodApp(s,[]), 
                             MethodApp("imp",[Method(MethodApp(lr,[]))])]))
      = (Imp,s,lr)
  | wave_tac _ = raise tac_err("mk_tac","wave_tac")
in
(*---------------------------------------------------------------------------
 * make_tac classifier ts X translates the Clam method X into a tacticAST.
 *
 *   classifier : applied to the strings "decision" and "wave", and its
 *                result is passed to CONJECTURE / conjecture.  Presumably it
 *                selects how external conjectures are handled -- TODO confirm.
 *   ts         : context passed to every mk_client_term, mk_client_type and
 *                client_var_name call.
 *
 * Returns the pair (tactic, !r).  r starts as the empty list and is only
 * handed to PROVED, CONJECTURE, conjecture and proved.  Its contents are
 * read after mk_tac X has been evaluated, because tuple components are
 * evaluated left to right.  Presumably those helpers record the external
 * facts the tactic depends on -- TODO confirm in VisibleTactic.
 *
 * Failures: the Induction cases bind x with a non-exhaustive val pattern
 * (raising Bind).  string_to_direction, wave_tac, unarg and an empty
 * MethodThen continuation raise Fail.  Unrecognised methods do not fail;
 * they become PRIM (HUH ...) leaves.  Clause order matters: more specific
 * patterns must precede the general ones they overlap with.
 *---------------------------------------------------------------------------*)
fun make_tac classifier ts X = 
 let val r = (ref []: (string * Fact.fact) Fact.age list ref)
     fun mk_tac X =
case X of
    (* Single-lambda induction: the Inhabit binders of the lambda become
       client variables (ivs), and their client types (htys) are used to
       translate every list of induction terms. *)
    (Induction(x,[lambda],_,itms)) =>
      (* Binds s from a "lemma" method application; any other shape of x
         raises Bind. *)
      let val (Method(MethodApp("lemma",[Method(MethodApp(s,[]))]))) = x
          fun ind_vars (Lambda(Inhabit idty,body)) = idty :: ind_vars body
            | ind_vars body = []
          fun convert (id,ty) =
              let val hty = mk_client_type ts ty
              in  (Clam2Client.mk_client_var
                      {Name=client_var_name ts id, Ty=hty},hty)
              end
          val (ivs,htys) = unzip (map convert (ind_vars lambda))
          val its =
             map (fn InductionTerms (tms,_) =>
                        map2 (mk_client_term ts) tms htys)
                itms
      in  PRIM (IND s ivs its)
      end
    (* Induction with any other number of lambdas (presumably mutual /
       multi-predicate induction).  Each lambda is translated as a client term
       whose type comes from ty_of_lam.  Each InductionTerms (tms,n) is
       translated using the argument types of the n-th lambda; that n shadows
       the outer n only inside the fn, and the outer n goes to
       MULTI_PRED_IND.  Lib.el is presumably 1-based, as in HOL. *)
  | (Induction(x,lambdas,n,itms)) =>
      let val (Method(MethodApp("lemma",[Method(MethodApp(s,[]))]))) = x
          (* A Prop chain with connective "=>" over the binder types, ending
             in the Clam identifier holbool. *)
          fun ty_of_lam (Lambda(Inhabit (_,ty),body)) =
                 ClamAST.Prop (ty,"=>",ty_of_lam body)
            | ty_of_lam _ = ClamAST.Identifier "holbool" (* Hack! *)
          fun arg_tys (Lambda(Inhabit (_,ty),body)) =
                 mk_client_type ts ty :: arg_tys body
            | arg_tys _ = []
          val ls = map (fn lambda =>
                           mk_client_term ts lambda
                              (mk_client_type ts (ty_of_lam lambda))) lambdas
          val htys = map arg_tys lambdas
          val its =
             map (fn InductionTerms (tms,n) =>
                        map2 (mk_client_term ts) tms (Lib.el n htys))
                itms
      in  PRIM (MULTI_PRED_IND n s ls its)
      end
  (* Wrapper methods: translate the enclosed method.  A Methods list is
     sequenced with THEN by mk_tacl.  elementary always becomes
     PRIM ELEMENTARY.  The dir argument of ripple is ignored. *)
  | (MethodApp("ind_strat",        [Method m]))          => mk_tac m
  | (MethodApp("induction_mutual", [Method m]))          => mk_tac m
  | (MethodApp("elementary",       [Method m]))          => PRIM ELEMENTARY
  | (MethodApp("base_case",        [Method(Methods L)])) => mk_tacl L
  | (MethodApp("base_case",        [Method m]))          => mk_tac m
  | (MethodApp("sym_eval",         [Method(Methods L)])) => mk_tacl L
  | (MethodApp("sym_eval",         [Method m]))          => mk_tac m
  | (MethodApp("intro",            [Method(Methods L)])) => mk_tacl L
  | (MethodApp("intro",            [Method m]))          => mk_tac m
  | (MethodApp("step_case",        [Method(Methods L)])) => mk_tacl L
  | (MethodApp("step_case",        [Method m]))          => mk_tac m
  | (MethodApp("normalize",        [Method(Methods L)])) => mk_tacl L
  | (MethodApp("normalize",        [Method m]))          => mk_tac m
  | (MethodApp("normalize_term",   [Method(Methods L)])) => mk_tacl L
  | (MethodApp("normalize_term",   [Method m]))          => mk_tac m
  | (MethodApp("ripple",           [dir, Method m]))     => mk_tac m
  | (MethodApp("ripple_and_cancel",[Method(Methods L)])) => mk_tacl L

  (* New identifiers become GENL_TAC.  Generalise replaces the translated
     term by a fresh client variable id of the translated type. *)
  | (New idlist) => PRIM (GENL_TAC idlist)
  | (Generalise (term,Inhabit (id,typ))) => 
       let val hty = mk_client_type ts typ
       in PRIM(GENERALISE(mk_client_term ts term hty, 
                          Clam2Client.mk_client_var{Name=id, Ty=hty})) 
       end
  (* Reduction: the next two differ only in "equ" vs. "equiv".  Both
     rewrite with OCC_RW at the given path; the "imp" form uses ANT_RW. *)
  | (MethodApp("reduction",[NumList path, 
       Method(Methods[MethodApp(s,[]),
         Equ(_,Method(MethodApp(lr,[])))])])) 
     => PRIM(OCC_RW (s, path, string_to_direction lr))
  | (MethodApp("reduction",[NumList path, 
       Method(Methods[MethodApp(s,[]),
         MethodApp("equiv",[Method(MethodApp(lr,[]))])])]))
     => PRIM(OCC_RW (s, path, string_to_direction lr))
  | (MethodApp("reduction",[NumList path, 
       Method(Methods[MethodApp(s,[]),
         MethodApp("imp",[Method(MethodApp(lr,[]))])])])) 
     => PRIM(ANT_RW (s, path, string_to_direction lr))
  (* Wave step: wave_tac decodes the rule; Eq selects OCC_RW and Imp selects
     ANT_RW.  The dir argument and the last argument are ignored. *)
  | (MethodApp("wave",[dir, NumList path, m, _])) 
     => let val (rty,s,lr) = wave_tac m
            val vtac = case rty of Eq => OCC_RW | Imp => ANT_RW
        in  PRIM(vtac (s, path, string_to_direction lr))
        end

  (* Unblocking wrappers translate the enclosed method.  A wave_front unblock
     with an "equiv" rule becomes OCC_RW at the given path. *)
  | (MethodApp("unblock_then_fertilize",[str,Method m])) => mk_tac m
  | (MethodApp("unblock_then_wave",[str,Method m]))      => mk_tac m
  | (MethodApp("unblock_fertilize_lazy",[Method m]))     => mk_tac m
  | (MethodApp("unblock_lazy",[Method m]))               => mk_tac m
  | (MethodApp("unblock",
       [Method (MethodApp ("wave_front",[])),
        NumList path,
         Method(Methods[MethodApp (s,[]),
              MethodApp("equiv",[Method (MethodApp (lr,[]))])])]))
      => PRIM(OCC_RW (s, path, string_to_direction lr))

  (* Fertilization.  Wrapper forms translate the enclosed method(s).
     StrongFertilize becomes SFERT at the given path. *)
  | (MethodApp("fertilize",
         [Method(MethodApp("strong",[])),Method m])) => mk_tac m
  | (StrongFertilize(NumList path,_,_))              => PRIM (SFERT path)
  | (MethodApp("fertilize",[str,Method m]))   => mk_tac m
  | (MethodApp("fertilize_then_ripple", l))   => mk_tacl (map unarg l)
  | (MethodApp("fertilize_left_or_right",
                 [dir, Method(Methods L)])) => mk_tacl L

  (* Weak fertilization: the path is extended by index 1 for LEFT or 2 for
     RIGHT.  The "=" case appends a further 1.  The "<=>" case uses
     WFERT_EQ; the "=>" case uses WFERT_IMP. *)
  | (WeakFertilize  (* fertilizing an "=" goal *)
         (Method(MethodApp(lr,[])),In,NumList path,_,_)) 
     => let val d = string_to_direction lr
           val index = case d of LEFT => 1 | RIGHT => 2
       in PRIM (WFERT_EQ(path@[index,1],d))
       end
  | (WeakFertilize  (* fertilizing an "<=>" goal *)
         (Method(MethodApp(lr,[])),Connective "<=>",NumList path,_,_)) 
     => let val d = string_to_direction lr
           val index = case d of LEFT => 1 | RIGHT => 2
       in PRIM (WFERT_EQ(path@[index],d))
       end
  | (WeakFertilize (* fertilizing an "==>" goal *)
         (Method(MethodApp(lr,[])),Connective "=>",NumList path,_,_)) 
     => let val d = string_to_direction lr
           val index = case d of LEFT => 1 | RIGHT => 2
       in 
        PRIM(WFERT_IMP (path@[index], string_to_direction lr))
       end
  (* Propositional steps that map to fixed primitive tactics. *)
  | (MethodApp("normal",
        [Method (MethodApp("univ_intro",_))]))  => PRIM MK_GEN_TAC
  | (MethodApp("normal",
        [Method (MethodApp("imply_intro",_))])) => PRIM MK_DISCH_TAC
  | (MethodApp("normal",
      [Method (MethodApp("conjunct_elim",_))])) => PRIM ASM_CONJ
  | (MethodApp("equal", _))                     => PRIM ASM_EQ
  | (MethodApp("idtac", _))                     => PRIM MK_ALL_TAC

  (* Two-way case split: CASE on the second disjunct p, translated at client
     boolean type.  The first disjunct (notp) is not used.  Splits with any
     other number of disjuncts become an HUH leaf. *)
  | (MethodApp("casesplit", [Method(Disjunction[notp,p])])) 
     => PRIM (CASE (mk_client_term ts p Clam2Client.client_bool))
  | (MethodApp("casesplit", 
      [Method(Disjunction _)])) => PRIM (HUH"cases>2")
  | (UseHypotheses(tms))
     => let val tms' =
               map (fn tm => mk_client_term ts tm Clam2Client.client_bool) tms
        in  PRIM (USE_HYPS tms')
        end
  | (MethodApp("bool_cases",[Method (MethodApp (s,[]))]))
     => let val name = Clam2Client.client_var_name ts s
        in  PRIM (BOOL_CASES (mk_client_var {Name = name,Ty = client_bool}))
        end
  | (TermCancel term)
     => PRIM (TERM_CANCEL (mk_client_term ts term client_bool))
  (* External methods.  A Thm s result becomes PROVED on the client fact id
     of s.  A Conjecture becomes CONJECTURE using classifier "decision".
     Hypotheses go through mapfilter, presumably Lib.mapfilter, which drops
     elements whose translation raises -- TODO confirm. *)
  | (MethodApp("external_lemma", [Method(External(_, _,Thm s))])) 
                                => PRIM(PROVED r (client_fact_id s))
  | (MethodApp("external_decision",
         [Method(External(_, _, Conjecture (tms,tm)))]))
     => let fun mk_ct tm = mk_client_term ts tm Clam2Client.client_bool
        in  PRIM (CONJECTURE (classifier "decision") r
                     (mapfilter mk_ct tms,mk_ct tm))
        end
  | (MethodApp("external_decision", [Method(External(_, _,Thm s))])) 
                                => PRIM(PROVED r (client_fact_id s))
  (* The next case is not robust with respect to the way the method works, *)
  (* but it does the job for now.                                          *)
  (* vtacf is handed to conjecture.  Given a Fact.Proved fact, it builds    *)
  (* ORW (Eq) or ARW (Imp) from the theorem.  Otherwise it yields           *)
  (* FAIL "Unproven wave rule".  The Thm variant looks up the proved        *)
  (* theorem directly via proved r.                                         *)
  | (MethodApp("external_wave",
               [_,_,Method(MethodThen
                              (External(_,_,Conjecture (tms,tm)),
                               MethodApp("wave",[dir,NumList path,m,_])))]))
     => let fun mk_ct tm = mk_client_term ts tm Clam2Client.client_bool
            fun vtacf (name,Fact.Proved (thm,_)) =
               let val (rty,_,lr) = wave_tac m
                   val vtac = case rty of Eq => ORW | Imp => ARW
               in  vtac (name, thm, path, string_to_direction lr)
               end
              | vtacf _ = FAIL "Unproven wave rule"
        in  PRIM (conjecture vtacf (classifier "wave") r
                     (mapfilter mk_ct tms,mk_ct tm))
        end
  | (MethodApp("external_wave",
               [_,_,Method(MethodThen
                              (External(_,_,Thm s),
                               MethodApp("wave",[dir,NumList path,m,_])))]))
     => let val name = client_fact_id s
            val thm = proved r name
            val (rty,_,lr) = wave_tac m
            val vtac = case rty of Eq => ORW | Imp => ARW
        in  PRIM (vtac (name, thm, path, string_to_direction lr))
        end
  (* Fallbacks.  Any other MethodApp becomes an HUH leaf naming the method.
     A Methods list is sequenced with THEN.  MethodThen becomes THEN, or
     THENL when the continuation has several methods.  An empty continuation
     is an error.  Anything else becomes an HUH leaf. *)
  | (MethodApp (id,_))          => PRIM (HUH id)
  | (Methods L)                 => mk_tacl L
  | (MethodThen(_, Methods[]))  => raise tac_err("mk_tac","MethodThen")
  | (MethodThen(m1,Methods[m])) => THEN(mk_tac m1, mk_tac m)
  | (MethodThen(m1,Methods L))  => THENL(mk_tac m1,map mk_tac L)
  | (MethodThen(m1,m2))         => THEN(mk_tac m1, mk_tac m2)
  |        x                    => PRIM (HUH"<not handled yet>")
and (* We should never get to these. *)
    (* mk_arg_tac is not called anywhere in make_tac; its ts parameter
       shadows the outer ts and is unused. *)
    mk_arg_tac ts (Method method) = mk_tac method
  | mk_arg_tac ts In              = PRIM (HUH"In")
  | mk_arg_tac ts (NumList L)     = PRIM (HUH"NumList")
  | mk_arg_tac _ _                = PRIM (HUH"<not handled yet>")
and 
    (* Translate each method and combine the results left to right with
       THEN.  end_itlist presumably fails on an empty list, as HOL's
       Lib.end_itlist does -- TODO confirm. *)
    mk_tacl L = end_itlist (curry THEN) (map mk_tac L)
in
  (mk_tac X, !r)
end 
end (* make_tac *)


(*---------------------------------------------------------------------------
 * Top level function. Build a tactic from a plan.
 *---------------------------------------------------------------------------*)
fun tactic_of classifier (plan, translations) =
    (* Only the method component of the plan is translated.  The goal and
       the other plan fields are not consulted here. *)
    case plan of
        ClamAST.Plan (_, _, _, method, _) =>
            make_tac classifier translations method;

end; (* structure PlanToTactic *)
