(* ======================================================================== *)
(* Connection between RAHD/MetiTarski and SMT solvers (Z3 only, currently)  *)
(*                                                                          *)
(*                version 0.00a, last updated 28-Feb-2012                   *)
(*                                                                          *)
(* by G.O.Passmore, Cambridge Computer Laboratory and LFCS, Edinburgh, 2012 *)
(* Contact: (e) grant.passmore@cl.cam.ac.uk   (w) www.cl.cam.ac.uk/~gp351/. *)
(* ======================================================================== *)

structure SMT :> SMT =
struct

open Useful;
open Common;

(* Set to true once an SMT solver process has been started. *)
val smt_used : bool ref = ref false;

(* Z3 model history.
   Two histories are kept: one for models whose sample points are all
   exact (stored as rationals), and one for models containing inexact
   values such as real algebraic numbers (stored as float approximations). *)

val model_history_rat : (string * Rat.rat) list list ref = ref [];
val model_history_float : (string * real) list list ref = ref [];

(* Whether models returned by Z3 are recorded in the histories above. *)

val use_model_history : bool ref = ref true;

(* ------------------------------------------------------------------------- *)
(* Printing of RCF formulas in SMT-Lib 2.0 format.                           *)
(* ------------------------------------------------------------------------- *)

(* Print a rational as an SMT-LIB 2.0 real term.
   An integer n is printed as the decimal "n.0"; a negative integer as the
   unary-minus application "(- n.0)", because SMT-LIB has no negative
   literals and Int.toString writes negatives with "~". A non-integer a/b
   is printed as "(/ a.0 b.0)". The numerator and denominator come from
   Rat.quotient_of_rat and are printed with Int.toString. *)
fun smt_print_rat (q : Rat.rat) =
    let
      (* SMT-LIB 2.0 decimals need at least one digit after the point,
         so emit "n.0" rather than "n.". *)
      fun smt_print_int i =
          if i < 0 then "(- " ^ Int.toString (abs i) ^ ".0)"
          else Int.toString i ^ ".0"
      val (a, b) = Rat.quotient_of_rat q
    in
      if b = 1 then smt_print_int a
      else "(/ " ^ smt_print_int a ^ " " ^ smt_print_int b ^ ")"
    end;

(* Print a term in SMT-LIB prefix syntax.
   Variables are named with Common.varname, constants with
   Common.no_underscores, and rationals with smt_print_rat. The supported
   function symbols are unary "-" and binary "^", "*", "-", "+". Any other
   symbol or arity raises Useful.Error. *)
fun smt_print_term t =
    let
      (* "(f a1 a2 ...)" with the arguments printed left to right. *)
      fun app (f, args) =
          "(" ^ f ^ " " ^ String.concatWith " " (map smt_print_term args) ^ ")"
      val binary_ops = ["^", "*", "-", "+"]
    in
      case t of
          Term.Var x => Common.varname x
        | Term.Rat r => smt_print_rat r
        | Term.Fn (s, []) => Common.no_underscores s
        | Term.Fn ("-", [arg]) => app ("-", [arg])
        | Term.Fn (opname, args as [_, _]) =>
          if List.exists (fn b => b = opname) binary_ops
          then app (opname, args)
          else raise Useful.Error ("smt_print_term: no match for " ^ opname)
        | Term.Fn (a, _) => raise Useful.Error ("smt_print_term: no match for " ^ a)
    end;

(* Print a binary atom (reln, [lhs, rhs]) as "(reln lhs rhs)". The relation
   symbol is emitted verbatim; atoms that are not binary raise Useful.Error. *)
fun smt_print_atom (atom : Atom.atom) =
    case atom of
        (reln, [lhs, rhs]) =>
        String.concat ["(", reln, " ", smt_print_term lhs, " ",
                       smt_print_term rhs, ")"]
      | _ => raise Useful.Error "smt_print_atom: atom's reln must be binary";

local open Formula
in
  (* Print a formula in SMT-LIB 2.0 syntax.
     Connectives map to the SMT-LIB Core operators. Equivalence is printed
     as "=" on Bool, since SMT-LIB has no "<=>" operator. *)
  fun smt_print_fml False = "false"
    | smt_print_fml True  = "true"
    | smt_print_fml (Atom a) = smt_print_atom a
    | smt_print_fml (Not p)  = "(not " ^ (smt_print_fml p) ^ ")"
    | smt_print_fml (And(p,q)) = "(and " ^ (smt_print_fml p) ^ " " ^ (smt_print_fml q) ^ ")"
    | smt_print_fml (Or(p,q))  = "(or " ^ (smt_print_fml p) ^ " " ^ (smt_print_fml q) ^ ")"
    | smt_print_fml (Imp(p,q)) = "(=> " ^ (smt_print_fml p) ^ " " ^ (smt_print_fml q) ^ ")"
    | smt_print_fml (Iff(p,q)) = "(= " ^ (smt_print_fml p) ^ " " ^ (smt_print_fml q) ^ ")"
    | smt_print_fml (Forall(x,p)) = smt_qquant "forall" (x,p)
    | smt_print_fml (Exists(x,p)) = smt_qquant "exists" (x,p)
  (* "(forall ((x Real)) body)" / "(exists ((x Real)) body)".
     The bound variable is named with Common.varname, the same function
     smt_print_term uses for Term.Var occurrences, so a binder and the
     occurrences it binds print identically. *)
  and smt_qquant qname (x,p) =
      "(" ^ qname ^ " ((" ^ (Common.varname x) ^ " Real)) " ^ (smt_print_fml p) ^ ")"
end;

(* Build an SMT-LIB script that declares each name in xvars as a Real
   constant and then asserts fml. Each command is on its own line. *)
fun smt_print_fml_with_consts xvars fml =
    let
      fun declare v = "(declare-fun " ^ v ^ " () Real)"
      val decls = String.concatWith "\n" (List.map declare xvars)
    in
      decls ^ "\n(assert " ^ smt_print_fml fml ^ ")"
    end;

(* Same as smt_print_fml_with_consts, but with the declarations and the
   assertion separated by spaces, so the script is a single line. *)
fun smt_print_fml_with_consts_oneline xvars fml =
    let
      fun declare v = "(declare-fun " ^ v ^ " () Real)"
      val decls = String.concatWith " " (List.map declare xvars)
    in
      decls ^ " (assert " ^ smt_print_fml fml ^ ")"
    end;
(* ------------------------------------------------------------------------- *)
(* SMT (Z3_nonlin) process I/O machinery                                     *)
(* ------------------------------------------------------------------------- *)

(* The running SMT solver process, with the stream we read its output from
   and the stream we write commands to; NONE when no solver is running. *)
val smt_proc : ((TextIO.instream, TextIO.outstream) Unix.proc
                * TextIO.instream * TextIO.outstream) option ref = ref NONE;

(* Human-readable description of a POSIX signal, used when reporting how
   the solver process ended. Signals not listed below are shown as
   "signalled " followed by their number (SysWord.toString, hexadecimal). *)
fun string_of_signal s =
    let
      val named = [(Posix.Signal.hup, "Hangup"),
                   (Posix.Signal.ill, "Illegal instruction"),
                   (Posix.Signal.int, "Interrupt"),
                   (Posix.Signal.kill, "Kill"),
                   (Posix.Signal.segv, "Segmentation violation"),
                   (Posix.Signal.bus, "Bus error")]
    in
      case List.find (fn (sg, _) => sg = s) named of
          SOME (_, name) => name
        | NONE => "signalled " ^ SysWord.toString (Posix.Signal.toWord s)
    end;
    
(* Short description of a Unix exit status, used in error reports. *)
fun stringOfStatus status =
    case status of
        Unix.W_EXITED => "exited"
      | Unix.W_EXITSTATUS code => "exit " ^ Word8.toString code
      | Unix.W_SIGNALED sg => string_of_signal sg
      | Unix.W_STOPPED _ => "stopped";

(* Signal: subprocess cpu time limit exceeded *)

(* SIGXCPU (CPU time limit exceeded). The Basis Posix.Signal structure has no
   binding for it, so it is built from the number 24.
   NOTE(review): the numeric value of SIGXCPU is platform-dependent; confirm
   it is 24 on the target platform. *)
val sigxcpu : Posix.Signal.signal = Posix.Signal.fromWord (0w24 : SysWord.word);

(* Close SMT process *)

(* Shut down the SMT process, if one is running.
   The process is sent SIGKILL and reaped, then smt_proc is cleared.
   Per the Basis spec, Unix.reap also closes the process's streams.
   Unless ignore_outcome is true, an unexpected exit status is reported on
   stdout. A normal exit or death by the SIGKILL we just sent is expected.
   SIGXCPU is reported as a CPU time-limit overrun. *)
fun smt_close ignore_outcome =
    case !smt_proc of
        NONE => ()
      | SOME (proc, _, _) =>
        let
          val _ = Unix.kill (proc, Posix.Signal.kill)
          val status = Unix.fromStatus (Unix.reap proc)
          (* Use the named SIGKILL constant rather than the literal 9: the
             literal only type-checks where Posix.Signal.signal is int. *)
          val expected = [Unix.W_EXITED, Unix.W_SIGNALED Posix.Signal.kill]
        in
          (if ignore_outcome orelse Useful.mem status expected then ()
           else if status = Unix.W_SIGNALED sigxcpu
           then print "Processor time limit exceeded for SMT solver\n"
           else print ("****ERROR: exit status = " ^ stringOfStatus status ^ "\n");
           smt_proc := NONE)
        end;

(* In-stream and Out-stream of smt_proc object *)

(* The stream carrying output from the running SMT solver, or NONE if no
   solver is running. *)
fun smt_is () = Option.map (fn (_, from_smt, _) => from_smt) (!smt_proc);

(* The stream used to send commands to the running SMT solver, or NONE if
   no solver is running. *)
fun smt_os () = Option.map (fn (_, _, to_smt) => to_smt) (!smt_proc);

(* Read lines from the stream `is` until one contains prompt_str. Each line
   is echoed at chat level 3.
   Returns List.rev acc followed by the lines read before that line; the
   prompt line itself is dropped.
   Raises Useful.Error if the stream reaches end-of-file.
   The stream passed in is read directly; the previous version shadowed
   `is` with the global solver stream and ignored the argument. *)
fun stream_strings_until_prompt is prompt_str acc =
    case TextIO.inputLine is of
        NONE => raise Useful.Error "SMT solver has unexpectedly terminated"
      | SOME line =>
        (Useful.chatting 3 andalso Useful.chat ("SMT: " ^ line);
         if String.isSubstring prompt_str line
         then List.rev acc
         else stream_strings_until_prompt is prompt_str (line :: acc));

(* Read lines from the stream `is` until one begins with prefix_str. Each
   line is echoed at chat level 3.
   Returns List.rev acc followed by the lines read before that line; the
   prefix line itself is dropped.
   Raises Useful.Error if the stream reaches end-of-file.
   The stream passed in is read directly; the previous version shadowed
   `is` with the global solver stream and ignored the argument. *)
fun stream_strings_until_prefix is prefix_str acc =
    case TextIO.inputLine is of
        NONE => raise Useful.Error "SMT solver has unexpectedly terminated"
      | SOME line =>
        (Useful.chatting 3 andalso Useful.chat ("SMT: " ^ line);
         if String.isPrefix prefix_str line
         then List.rev acc
         else stream_strings_until_prefix is prefix_str (line :: acc));


(* Send s, followed by a newline, to the SMT solver.
   The stream is flushed afterwards, so the command actually reaches the
   solver before we block reading its reply (pipe streams are buffered).
   Raises Useful.Error if no solver process is running. *)
fun smt_writeln (s) =
    case smt_os() of
        SOME os => (TextIO.output (os, s ^ "\n"); TextIO.flushOut os)
      | NONE => raise Useful.Error "No SMT process.";

(* Z3 config string for creating models *)

val z3_produce_models_str =
    String.concat ["(set-option :produce-models true)",
                   "(set-option :pp-decimal true)",
                   "(set-option :pp-decimal-precision 20)"];

(* Open an SMT process (if none is running) and set up the global smt_proc with its I/O streams *)

(* Return the running SMT process and its streams, starting one if needed.
   The binary path is read from the Z3_NONLIN environment variable
   (Useful.die is called if it is unset) and run with "-si -smt2".
   A newly started process is sent z3_produce_models_str, and smt_used is
   set to true. *)
fun smt_open () =
    case !smt_proc of
        SOME handles => handles
      | NONE =>
        let
          val z3_path =
              case OS.Process.getEnv "Z3_NONLIN" of
                  SOME path => path
                | NONE => Useful.die "Environment variable Z3_NONLIN must point to Z3 nonlinear (nlsat) binary"
          val proc = Unix.execute (z3_path, ["-si", "-smt2"])
          val (from_smt, to_smt) = Unix.streamsOf proc
          val handles = (proc, from_smt, to_smt)
        in
          smt_used := true;
          smt_proc := SOME handles;
          smt_writeln z3_produce_models_str;
          handles
        end;

(* The first string in ss (in list order) that occurs somewhere in line,
   or NONE if none does. *)
fun first_substring ss line =
    List.find (fn s => String.isSubstring s line) ss;

(* Read lines from `is` until one contains a member of ss, and return the
   first such member (in ss order). Each line is echoed at chat level 3.
   A line containing "Error" raises Bug, and end-of-file raises Error. *)
fun strings_in_stream ss is =
    let
      fun next () =
          case TextIO.inputLine is of
              NONE => raise Error "SMT solver has unexpectedly terminated"
            | SOME line => examine line
      and examine line =
          let
            val _ = chatting 3 andalso chat ("SMT: " ^ line)
          in
            if String.isSubstring "Error" line
            then raise Bug ("SMT ERROR MSG: " ^ line)
            else case first_substring ss line of
                     SOME s => s
                   | NONE => next ()
          end
    in
      next ()
    end;

(* Read lines from `is` until one begins with p. Each line is echoed at chat
   level 3.
   Returns List.rev ls followed by every line read, including the line that
   begins with p. Raises Error on end-of-file. *)
fun strings_in_stream_until_prefix' p ls is =
    case TextIO.inputLine is of
        NONE => raise Error "SMT solver has unexpectedly terminated"
      | SOME line =>
        let
          val _ = chatting 3 andalso chat ("SMT: " ^ line)
          val seen = line :: ls
        in
          if String.isPrefix p line then List.rev seen
          else strings_in_stream_until_prefix' p seen is
        end;

(* Read lines from the running solver's output stream up to and including
   the first line beginning with p. Raises Useful.Error if no solver is
   running. *)
fun strings_in_stream_until_prefix p =
    case smt_is () of
        SOME from_smt => strings_in_stream_until_prefix' p [] from_smt
      | NONE => raise Useful.Error "No input stream for SMT.";

(* A parser for Z3 models, using the method of recursive descent.

   Example Z3 model string:
    (model (define-fun x () Real 0.5)
           (define-fun y () Real 12.91?)
           (define-fun z () Real 390.01))

   Result:

    > z3_parse_model "(model (define-fun x () Real 0.5) 
                      (define-fun y () Real 12.91?) 
                      (define-fun z () Real 390.01))";

    val it = [("x", 0.5, true), ("y", 12.91, false), ("z", 390.01, true)]:
      (string * real * bool) list

  The result is a list of triples (v, f, e?): v is a variable name, f is
  its value as a real, and e? is true iff Z3 marks the sample point as
  exact, i.e. it is not flagged with a `?' in Z3's printout of the model.

  Only values of the forms d.d, d.d?, (- d.d) and (- d.d?) are accepted.
  Anything else, including a truncated model, raises Z3_PARSE_MODEL.

  Rounding is not an issue here w.r.t. soundness. These models are only
  ever used to accelerate the recognition of SAT for RCF formulas; they
  are never used to prove that anything is UNSAT.

  The string is tokenised with Common.lex. The patterns below assume it
  splits "define-fun" into "define" "-" "fun" and "0.5" into "0" "." "5".
  NOTE(review): this is inferred from the patterns; confirm against
  Common.lex. *)

exception Z3_PARSE_MODEL of string;

fun z3_parse_model m_str =
    let
      (* The real written h.t; raises Z3_PARSE_MODEL err if it is malformed. *)
      fun decimal (h, t, err) =
          case Real.fromString (String.concat [h, ".", t]) of
              SOME r => r
            | NONE => raise Z3_PARSE_MODEL err

      (* Parse one define-fun whose opening "(" has already been consumed.
         Returns its (name, value, exact) triple together with the tokens
         after its closing ")". A trailing "?" marks an inexact witness. *)
      fun parse_fun l =
          case l of
              [] => raise Z3_PARSE_MODEL "failure in parse_fun(1)"
            | ("define" :: "-" :: "fun" :: name :: "(" :: ")" :: "Real"
               :: h :: "." :: t :: ")" :: rst) =>
              ((name, decimal (h, t, "failure in parse_fun(2)"), true), rst)
            | ("define" :: "-" :: "fun" :: name :: "(" :: ")" :: "Real"
               :: h :: "." :: t :: "?" :: ")" :: rst) =>
              ((name, decimal (h, t, "failure in parse_fun(3)"), false), rst)
            | ("define" :: "-" :: "fun" :: name :: "(" :: ")" :: "Real"
               :: "(" :: "-" :: h :: "." :: t :: ")" :: ")" :: rst) =>
              ((name, ~ (decimal (h, t, "failure in parse_fun(2)")), true), rst)
            | ("define" :: "-" :: "fun" :: name :: "(" :: ")" :: "Real"
               :: "(" :: "-" :: h :: "." :: t :: "?" :: ")" :: ")" :: rst) =>
              ((name, ~ (decimal (h, t, "failure in parse_fun(3)")), false), rst)
            | _ => raise Z3_PARSE_MODEL "failure in parse_fun(4)"

      (* Parse a run of define-funs, skipping any "(" before each one, and
         stop after the first define-fun that is not followed by "(".
         If the tokens run out here, an empty remainder is returned, and
         parse_wrapper reports it as Z3_PARSE_MODEL rather than letting
         List.Empty escape. *)
      fun parse_funs ("(" :: r) = parse_funs r
        | parse_funs l =
          (case parse_fun l of
               (f, r as ("(" :: _)) =>
               let val (fs, rest) = parse_funs r
               in (f :: fs, rest) end
             | (f, r) => ([f], r))

      (* "(model <define-fun>* )": return the parsed triples and any tokens
         after the closing parenthesis. *)
      fun parse_wrapper l =
          case l of
              [] => raise Z3_PARSE_MODEL "failure in parse_wrapper(1)"
            | ("(" :: "model" :: r) =>
              (case parse_funs r of
                   (model_lst, ")" :: s) => (model_lst, s)
                 | _ => raise Z3_PARSE_MODEL "failure in parse_wrapper(2)")
            | _ => raise Z3_PARSE_MODEL "failure in parse_wrapper(3)"
    in
      #1 (parse_wrapper (Common.lex (explode m_str)))
    end;

(* Read a model from a current Z3 process which has deduced its current
   context to be SAT. *)

(* Ask Z3, which has just answered "sat", for its current model, and parse
   it with z3_parse_model. The reply is collected up to and including the
   first line that begins with ")\n". *)
fun z3_get_model () =
    let
      val () = smt_writeln "(get-model)"
      val reply_lines = strings_in_stream_until_prefix ")\n"
    in
      z3_parse_model (String.concat reply_lines)
    end;

(* Get a model and add it to the model history. *)

(* If use_model_history is set, fetch Z3's current model and record it.
   If every value is exact, the model is converted with rat_of_float and
   pushed onto model_history_rat; otherwise its float values are pushed
   onto model_history_float. *)
fun z3_process_model () =
    if not (!use_model_history) then ()
    else
      let
        val model = z3_get_model ()
        val all_exact = List.all (fn (_, _, exact) => exact) model
      in
        if all_exact
        then model_history_rat :=
               (map (fn (v, x, _) => (v, rat_of_float x)) model) :: (!model_history_rat)
        else model_history_float :=
               (map (fn (v, x, _) => (v, x)) model) :: (!model_history_float)
      end;

(* ------------------------------------------------------------------------- *)
(* SMT (Z3_nonlin) SAT/UNSAT decision function (only for Exists formulas)    *)
(* ------------------------------------------------------------------------- *)
(* Returns true (for success) iff SMT solver (Z3_nonlin) returns UNSAT       *)

(* Z3 tactic strings for NLSAT. All of them run the same chain of
   preprocessing tactics before nlsat; they differ in the final steps. *)

local
  (* Preprocessing tactics shared by every NLSAT strategy below. *)
  val preprocess =
      "simplify purify-arith propagate-values elim-term-ite solve-eqs tseitin-cnf"
in

(* NLSAT with the requisite pre-processing. *)
val nlsat_str =
    "(and-then " ^ preprocess ^ " simplify nlsat)";

(* NLSAT with factorisation disabled (:factor false) and
   :algebraic-min-mag 256. *)
val nlsat_no_factor_str =
    "(and-then " ^ preprocess
    ^ " simplify (using-params nlsat :factor false :algebraic-min-mag 256))";

(* A complete check-sat-using command: NLSAT with :shuffle-vars true,
   :seed 13. *)
val nlsat_var_shuffle_str =
    "(check-sat-using (and-then " ^ preprocess
    ^ " simplify (using-params nlsat :shuffle-vars true :seed 13)))";

(* Runs Z3's factor tactic after tseitin-cnf, then simplify and nlsat. *)
val nlsat_factor_before_str =
    "(and-then " ^ preprocess ^ " factor simplify nlsat)";

end;

(* Decide an existential RCF problem with Z3/NLSAT. Returns true iff Z3
   answers "unsat".
   - The variables in xvars are declared as Real constants and fm is asserted.
   - False and True are decided without calling Z3. With no variables the
     problem is abandoned (false).
   - Univariate problems use NLSAT without factorisation; others factorise
     before NLSAT.
   - On "sat" the model is recorded via z3_process_model.
   - The solver context is reset afterwards. *)
fun smt_unsat _ Formula.False = true
  | smt_unsat _ Formula.True = false
  | smt_unsat [] _ = false    (*no variables, so abandon*)
  | smt_unsat xvars fm =
    let
      val fml_str = smt_print_fml_with_consts xvars fm
      val _ = chatting 2 andalso chat ("----- Calling Z3 on\n" ^ fml_str ^ "\n-----")
      val (_, from_smt, _) = smt_open ()
      val check_str =
          if length xvars = 1 then nlsat_no_factor_str else nlsat_factor_before_str
      val () = smt_writeln fml_str
      val () = smt_writeln ("(check-sat-using " ^ check_str ^ ")")
      val result = strings_in_stream ["failed", "unsat", "sat", "unknown", "error"] from_smt
      val _ = chatting 2 andalso chat ("----- Z3 result: " ^ result ^ "\n")
      val () = if result = "sat" then z3_process_model () else ()
      val () = smt_writeln "(reset)"
    in
      result = "unsat"
    end;

(* ------------------------------------------------------------------------- *)
(* SMT (Z3_nonlin) SAT/UNSAT decision function (only for Exists formulas),   *)
(*  with user-controllable proof strategies.                                 *)
(* ------------------------------------------------------------------------- *)
(* Returns true (for success) iff SMT solver (Z3_nonlin) returns UNSAT       *)

(* Like smt_unsat, but with a caller-supplied strategy. Returns true iff Z3
   answers "unsat".
   - init_str, if non-empty, is sent first (e.g. set-option commands).
   - The formula is then asserted, with xvars declared as Real constants.
   - strategy_str is sent verbatim; callers pass a complete
     check-sat(-using) command.
   - False and True are decided directly. With no variables the problem is
     abandoned (false).
   - On "sat" the model is recorded via z3_process_model.
   - The context is reset afterwards. *)
fun smt_unsat_with_strategy _ Formula.False _ _ = true
  | smt_unsat_with_strategy _ Formula.True _ _ = false
  | smt_unsat_with_strategy [] _ _ _ = false     (* no variables, so abandon *)
  | smt_unsat_with_strategy xvars fm init_str strategy_str =
    let
      val fml_str = smt_print_fml_with_consts xvars fm
      val _ = chatting 2 andalso chat ("----- Calling Z3 on\n" ^ fml_str ^ "\n-----")
      val (_, from_smt, _) = smt_open ()
      val () = if init_str = "" then () else smt_writeln init_str
      val () = smt_writeln fml_str
      val () = smt_writeln strategy_str
      val result = strings_in_stream ["failed", "unsat", "sat", "unknown", "error"] from_smt
      val _ = chatting 2 andalso chat ("----- Z3 result: " ^ result ^ "\n")
      val () = if result = "sat" then z3_process_model () else ()
      val () = smt_writeln "(reset)"
    in
      result = "unsat"
    end;

(* ------------------------------------------------------------------------- *)
(* SMT (Z3_nonlin) Judgment decision function (only for Exists formulas),    *)
(*  with user-controllable proof strategies.                                 *)
(* ------------------------------------------------------------------------- *)
(* Returns Common.tv       *)

(* Run Z3 on fm with a caller-supplied strategy and classify the answer as
   a Common.tv:
   - "unsat" gives Unsat.
   - "sat" gives Sat NONE. No model is attached, although it is still
     recorded by z3_process_model.
   - Any other answer ("unknown", "failed", "error") gives Unknown.
   init_str, if non-empty, is sent before the formula. strategy_str is sent
   verbatim after it; callers pass a complete check-sat(-using) command.
   The context is reset afterwards.
   False and True are decided directly. With no variables the call is
   abandoned and Unknown is returned: reporting Sat there would be wrong
   for a ground formula that is actually false. *)
fun smt_judgment_with_strategy _ Formula.False _ _ = Common.Unsat
  | smt_judgment_with_strategy _ Formula.True _ _ = Common.Sat NONE
  | smt_judgment_with_strategy [] _ _ _ = Common.Unknown     (* no variables, so abandon *)
  | smt_judgment_with_strategy xvars fm init_str strategy_str =
    let
      val fml_str = smt_print_fml_with_consts xvars fm
      val _ = chatting 2 andalso chat ("----- Calling Z3 on\n" ^ fml_str ^ "\n-----")
      val (_, from_smt, _) = smt_open ()
      val () = if init_str = "" then () else smt_writeln init_str
      val () = smt_writeln fml_str
      val () = smt_writeln strategy_str
      val result = strings_in_stream ["failed", "unsat", "sat", "unknown", "error"] from_smt
      val _ = chatting 2 andalso chat ("----- Z3 result: " ^ result ^ "\n")
      val () = if result = "sat" then z3_process_model () else ()
      val () = smt_writeln "(reset)"
    in
      if result = "unsat" then Common.Unsat
      else if result = "sat" then Common.Sat NONE
      else Common.Unknown
    end;

(* Further Z3 strategies, each a complete check-sat-using command. All of
   them run the same preprocessing and then nlsat with the given
   parameters. *)

local
  (* check-sat-using command running nlsat with the given parameter list. *)
  fun with_nlsat params =
      "(check-sat-using (and-then simplify purify-arith propagate-values "
      ^ "elim-term-ite solve-eqs tseitin-cnf (using-params nlsat " ^ params ^ ")))"
in

(* nlsat with :max-conflicts 2.
   NOTE(review): the original comment said "stop after a conflict is
   detected", but the limit here is 2, versus 1 in the no-factor variant.
   Confirm which is intended. *)
val nlsat_no_conflict_str = with_nlsat ":max-conflicts 2";

(* nlsat with :max-conflicts 1 and factorisation disabled. *)
val nlsat_no_conflict_no_factor_str = with_nlsat ":max-conflicts 1 :factor false";

(* nlsat with a shuffled variable order (:shuffle-vars true, :seed 2). *)
val nlsat_random_var_ord_str = with_nlsat ":shuffle-vars true :seed 2";

end;

(* Run (and-then simplify smt) with :timeout timelimit. The nl-arith-gb,
   nl-arith and nlsat options are set to false first.
   Returns smt_unsat_with_strategy's verdict (true = UNSAT). *)
fun z3_linear_relax xvars fm timelimit =
    let
      val options =
          String.concatWith "\n" ["(set-option :nl-arith-gb false)",
                                  "(set-option :nl-arith false)",
                                  "(set-option :nlsat false)"]
      val command =
          String.concat ["(check-sat-using (and-then simplify smt) :timeout ",
                         Int.toString timelimit, ")"]
    in
      smt_unsat_with_strategy xvars fm options command
    end;

(* NLSAT with preprocessing under :timeout timelimit. Univariate problems
   use the variant without factorisation; others factorise first.
   Returns smt_unsat_with_strategy's verdict (true = UNSAT). *)
fun z3_nlsat xvars fm timelimit =
    let
      val tactic =
          case xvars of
              [_] => nlsat_no_factor_str
            | _ => nlsat_factor_before_str
      val command =
          String.concat ["(check-sat-using ", tactic, " :timeout ",
                         Int.toString timelimit, ")"]
    in
      smt_unsat_with_strategy xvars fm "" command
    end;

(* NLSAT without factorisation under :timeout timelimit.
   Returns smt_unsat_with_strategy's verdict (true = UNSAT). *)
fun z3_nlsat_no_factor xvars fm timelimit =
    let
      val command =
          String.concat ["(check-sat-using ", nlsat_no_factor_str, " :timeout ",
                         Int.toString timelimit, ")"]
    in
      smt_unsat_with_strategy xvars fm "" command
    end;

(* Run (and-then simplify smt) with :timeout timelimit, after setting
   :nl-arith-gb false and :nl-arith true.
   Returns smt_unsat_with_strategy's verdict (true = UNSAT). *)
fun z3_nl_arith xvars fm timelimit =
    let
      val options =
          String.concatWith "\n" ["(set-option :nl-arith-gb false)",
                                  "(set-option :nl-arith true)"]
      val command =
          String.concat ["(check-sat-using (and-then simplify smt) :timeout ",
                         Int.toString timelimit, ")"]
    in
      smt_unsat_with_strategy xvars fm options command
    end;

(* Run (and-then simplify smt) with :timeout timelimit, after setting
   :nl-arith-gb true and :nl-arith false.
   Returns smt_unsat_with_strategy's verdict (true = UNSAT). *)
fun z3_nl_arith_gb xvars fm timelimit =
    let
      val options =
          String.concat ["(set-option :nl-arith-gb true)\n",
                         "(set-option :nl-arith false)\n"]
      val command =
          String.concat ["(check-sat-using (and-then simplify smt) :timeout ",
                         Int.toString timelimit, ")"]
    in
      smt_unsat_with_strategy xvars fm options command
    end;

(* Try Z3 with a small nlsat conflict limit (nlsat_no_conflict_str), returning a Common.tv *)

(* Judge fm with nlsat_no_conflict_str, sending no initial options. *)
fun z3_no_conflicts_judgment xvars fm =
    let
      val no_init = ""
    in
      smt_judgment_with_strategy xvars fm no_init nlsat_no_conflict_str
    end;

(* Try Z3 with a small nlsat conflict limit and factorisation disabled
   (nlsat_no_conflict_no_factor_str), returning a Common.tv *)

(* Judge fm with nlsat_no_conflict_no_factor_str, sending no initial
   options. *)
fun z3_no_conflicts_no_factor_judgment xvars fm =
    let
      val no_init = ""
    in
      smt_judgment_with_strategy xvars fm no_init nlsat_no_conflict_no_factor_str
    end;


end;
