(* File ‹Tools/Mirabelle/mirabelle_sledgehammer_filter.ML› *)
structure Mirabelle_Sledgehammer_Filter: MIRABELLE_ACTION =
struct
(* Look up option "name" in the association list "args" and parse its value
   as a real; when the option is absent, return "default_value" instead. *)
fun get args name default_value =
  AList.lookup (op =) args name
  |> Option.map Value.parse_real
  |> the_default default_value
(* Build a relevance-fudge record from the given default record: every field
   whose name occurs as an option in "args" is replaced by the parsed option
   value (see "get"); all other fields keep their default. *)
fun extract_relevance_fudge args
    {local_const_multiplier, worse_irrel_freq, higher_order_irrel_weight, abs_rel_weight,
     abs_irrel_weight, theory_const_rel_weight, theory_const_irrel_weight,
     chained_const_irrel_weight, intro_bonus, elim_bonus, simp_bonus, local_bonus, assum_bonus,
     chained_bonus, max_imperfect, max_imperfect_exp, threshold_divisor, ridiculous_threshold} =
  let
    (* Option value for "name" if present in args, otherwise "default". *)
    fun override name default = get args name default
  in
    {local_const_multiplier = override "local_const_multiplier" local_const_multiplier,
     worse_irrel_freq = override "worse_irrel_freq" worse_irrel_freq,
     higher_order_irrel_weight = override "higher_order_irrel_weight" higher_order_irrel_weight,
     abs_rel_weight = override "abs_rel_weight" abs_rel_weight,
     abs_irrel_weight = override "abs_irrel_weight" abs_irrel_weight,
     theory_const_rel_weight = override "theory_const_rel_weight" theory_const_rel_weight,
     theory_const_irrel_weight = override "theory_const_irrel_weight" theory_const_irrel_weight,
     chained_const_irrel_weight =
       override "chained_const_irrel_weight" chained_const_irrel_weight,
     intro_bonus = override "intro_bonus" intro_bonus,
     elim_bonus = override "elim_bonus" elim_bonus,
     simp_bonus = override "simp_bonus" simp_bonus,
     local_bonus = override "local_bonus" local_bonus,
     assum_bonus = override "assum_bonus" assum_bonus,
     chained_bonus = override "chained_bonus" chained_bonus,
     max_imperfect = override "max_imperfect" max_imperfect,
     max_imperfect_exp = override "max_imperfect_exp" max_imperfect_exp,
     threshold_divisor = override "threshold_divisor" threshold_divisor,
     ridiculous_threshold = override "ridiculous_threshold" ridiculous_threshold}
  end
(* Table whose keys are pairs of integers, ordered by the product of int_ord
   on both components. Used below with (line number, offset) keys. *)
structure Prooftab =
Table(type key = int * int val ord = prod_ord int_ord int_ord)
(* Render the current content of a synchronized integer variable as a string. *)
fun print_int var = var |> Synchronized.value |> Value.print_int
(* Integer percentage of a relative to b (rounded by "div"), printed as a
   string; "N/A" when b is zero so that no division by zero occurs. *)
fun percentage _ 0 = "N/A"
  | percentage a b = Value.print_int (a * 100 div b)
(* Percentage of a within the total a + b. *)
fun percentage_alt a b =
  let val total = a + b
  in percentage a total end
(* Prover name used by make_action when no "prover" option is supplied. *)
val default_prover = ATP_Proof.eN
(* Format an (index, name) pair as "name@index". *)
fun with_index (idx, name) = String.concat [name, "@", Value.print_int idx]
(* Name of the action argument that holds the path of the proof file. *)
val proof_fileK = "proof_file"
(* Create the "sledgehammer_filter" Mirabelle action.

   The action arguments must contain a "proof_file" entry naming a file with
   one known proof per line, in the form "LINE:OFFSET:fact1 fact2 ...".  For
   every command whose position matches an entry of that file, the relevance
   filter is run and the action reports which of the recorded facts/proofs
   are still covered by the selected facts.

   Returns the pair ("", {run, finalize}):
     - run: produces the per-command log text and updates the counters;
     - finalize: produces the overall statistics, or "" if no command with a
       known proof was processed.

   Raises ERROR when no "proof_file" argument is given. *)
fun make_action ({arguments, ...} : Mirabelle.action_context) =
  let
    (* Separate the "proof_file" argument from the remaining options and load
       the table of known proofs, keyed by (line number, offset). *)
    val (proof_table, args) =
      let
        val (pf_args, other_args) =
          List.partition (curry (op =) proof_fileK o fst) arguments
        val proof_file =
          (case pf_args of
            [] => error "No \"proof_file\" specified"
          | (_, s) :: _ => s)
        (* Parse one "LINE:OFFSET:facts" line; the facts are separated by
           spaces, empty names are dropped.  Lines that do not consist of
           exactly three ":"-separated fields are ignored. *)
        fun do_line line =
          (case space_explode ":" line of
            [line_num, offset, proof] =>
              SOME (apply2 (the o Int.fromString) (line_num, offset),
                proof |> space_explode " " |> filter_out (curry (op =) ""))
          | _ => NONE)
        (* Several lines may refer to the same position: their proofs are
           collected under one key before the table is built. *)
        val proof_table =
          File.read (Path.explode proof_file)
          |> space_explode "\n"
          |> map_filter do_line
          |> AList.coalesce (op =)
          |> Prooftab.make
      in (proof_table, other_args) end

    (* Global statistics, accumulated over all invocations of "run". *)
    val num_successes = Synchronized.var "num_successes" 0
    val num_failures = Synchronized.var "num_failures" 0
    val num_found_proofs = Synchronized.var "num_found_proofs" 0
    val num_lost_proofs = Synchronized.var "num_lost_proofs" 0
    val num_found_facts = Synchronized.var "num_found_facts" 0
    val num_lost_facts = Synchronized.var "num_lost_facts" 0

    (* Process a single command: returns "" if the position has no line or
       offset, "No known proof" if the proof table has no entry for it, and
       otherwise a three-part log (outcome, found facts, lost facts). *)
    fun run ({pos, pre, ...} : Mirabelle.command) =
      let
        val results =
          (case (Position.line_of pos, Position.offset_of pos) of
            (SOME line_num, SOME offset) =>
              (case Prooftab.lookup proof_table (line_num, offset) of
                SOME proofs =>
                  let
                    val thy = Proof.theory_of pre
                    val {context = ctxt, facts = chained_ths, goal} = Proof.goal pre
                    val prover = AList.lookup (op =) args "prover" |> the_default default_prover
                    val params as {max_facts, ...} = Sledgehammer_Commands.default_params thy args
                    val default_max_facts = 256
                    val relevance_fudge =
                      extract_relevance_fudge args Sledgehammer_MePo.default_relevance_fudge
                    val subgoal = 1
                    val (_, hyp_ts, concl_t) = ATP_Util.strip_subgoal goal subgoal ctxt
                    val keywords = Thy_Header.get_keywords' ctxt
                    val css_table = Sledgehammer_Fact.clasimpset_rule_table_of ctxt
                    (* Names of the facts selected by the relevance filter, in
                       the order in which the filter returns them. *)
                    val facts =
                      Sledgehammer_Fact.nearly_all_facts ctxt false
                        Sledgehammer_Fact.no_fact_override keywords css_table chained_ths
                        hyp_ts concl_t
                      |> Sledgehammer_Fact.drop_duplicate_facts
                      |> Sledgehammer_MePo.mepo_suggested_facts ctxt params
                        (the_default default_max_facts max_facts)
                        (SOME relevance_fudge) hyp_ts concl_t
                      |> map (fst o fst)
                    (* Every distinct fact used by some recorded proof is either
                       found (paired with its index in "facts") or lost. *)
                    val (found_facts, lost_facts) =
                      flat proofs |> sort_distinct string_ord
                      |> map (fn fact => (find_index (curry (op =) fact) facts, fact))
                      |> List.partition (curry (op <=) 0 o fst)
                      |>> sort (prod_ord int_ord string_ord) ||> map snd
                    (* A proof is found if all of its facts were selected. *)
                    val found_proofs = filter (forall (member (op =) facts)) proofs
                    val n = length found_proofs
                    (* Each command is counted exactly once: as a failure when
                       no recorded proof is covered, as a success otherwise. *)
                    val log1 =
                      if n = 0 then
                        (Synchronized.change num_failures (curry op+ 1); "Failure")
                      else
                        (Synchronized.change num_successes (curry op+ 1);
                         Synchronized.change num_found_proofs (curry op+ n);
                         "Success (" ^ Value.print_int n ^ " of " ^
                           Value.print_int (length proofs) ^ " proofs)")
                    val _ = Synchronized.change num_lost_proofs (curry op+ (length proofs - n))
                    val _ = Synchronized.change num_found_facts (curry op+ (length found_facts))
                    val _ = Synchronized.change num_lost_facts (curry op+ (length lost_facts))
                    val log2 =
                      if null found_facts then
                        ""
                      else
                        let
                          (* Root mean square of the found facts' indices,
                             rounded up to an integer. *)
                          val found_weight =
                            Real.fromInt (fold (fn (n, _) => Integer.add (n * n)) found_facts 0)
                              / Real.fromInt (length found_facts)
                            |> Math.sqrt |> Real.ceil
                        in
                          "Found facts (among " ^ Value.print_int (length facts) ^
                          ", weight " ^ Value.print_int found_weight ^ "): " ^
                          commas (map with_index found_facts)
                        end
                    val log3 =
                      if null lost_facts then
                        ""
                      else
                        "Lost facts (among " ^ Value.print_int (length facts) ^ "): " ^
                        commas lost_facts
                  in cat_lines [log1, log2, log3] end
              | NONE => "No known proof")
          | _ => "")
      in
        results
      end

    (* Summarize the accumulated counters; empty if nothing was counted. *)
    fun finalize () =
      if Synchronized.value num_successes + Synchronized.value num_failures > 0 then
        "\nNumber of overall successes: " ^ print_int num_successes ^
        "\nNumber of overall failures: " ^ print_int num_failures ^
        "\nOverall success rate: " ^
          percentage_alt (Synchronized.value num_successes)
            (Synchronized.value num_failures) ^ "%" ^
        "\nNumber of found proofs: " ^ print_int num_found_proofs ^
        "\nNumber of lost proofs: " ^ print_int num_lost_proofs ^
        "\nProof found rate: " ^
          percentage_alt (Synchronized.value num_found_proofs)
            (Synchronized.value num_lost_proofs) ^ "%" ^
        "\nNumber of found facts: " ^ print_int num_found_facts ^
        "\nNumber of lost facts: " ^ print_int num_lost_facts ^
        "\nFact found rate: " ^
          percentage_alt (Synchronized.value num_found_facts)
            (Synchronized.value num_lost_facts) ^ "%"
      else
        ""
  in ("", {run = run, finalize = finalize}) end
(* Register make_action with Mirabelle under the name "sledgehammer_filter". *)
val () = Mirabelle.register_action "sledgehammer_filter" make_action
end