(* =========================================================================== *)
(*         Arithmetic in the (Ordered) Field of Real Algebraic Numbers         *)
(*                                                                             *)
(* by G.O.Passmore, Aesthetic Integration Ltd & Clare Hall, Univ. of Cambridge *)
(* Contact:   (e) grant.passmore@cl.cam.ac.uk    (w) www.cl.cam.ac.uk/~gp351/. *)
(* =========================================================================== *)

structure RealAlg : RealAlg =
struct

open Algebra;
open Groebner;
open Resultant;
open Sturm;

(* Exception: Described real algebraic number is not unique. *)

exception Non_unique of string;

(* Interval representation of a real algebraic number: (p, l, u)
   stands for the unique real root of the univariate polynomial p (in
   x0) lying in the closed interval [l, u]. *)

type real_alg = Algebra.poly * Rat.rat * Rat.rat;

(* Polynomial `x' (actually, `x0') *)

val p_x = [(Rat.one, [(0, 1)])] : poly;

(* A unique real_alg representation of zero: (x, 0, 0). *)

val zero = (p_x, Rat.zero, Rat.zero) : real_alg;

(* Absolute value of a rational. *)

fun rat_abs q = if Rat.le (Rat.zero, q) then q else Rat.neg q;

(* Given a square-free polynomial p with exactly one real root r in
   [l, u], where l <= 0 <= u and r <> 0 (i.e., `x' does not divide p,
   so p's constant coefficient a_0 is non-zero), shrink [l, u] so that
   it no longer contains zero.

   Root bound: applying Cauchy's bound to the reciprocal polynomial
   x^n p(1/x) shows that every non-zero root r of
   p = a_n x^n + ... + a_1 x + a_0 satisfies
       |r| >= |a_0| / (|a_0| + max_{i>=1} |a_i|).
   With S = Sum_{i=0}^n |a_i| (assumed to be what
   Algebra.p_sum_abs_coeffs computes) we have
   |a_0| + S > |a_0| + max_{i>=1} |a_i|, hence
   e = |a_0| / (|a_0| + S) < |r| for every root r, so p has no root in
   [-e, e].  The |a_0| factor is essential: p has been made monic, so
   its coefficients are rationals and |a_0| may be far below 1
   (e.g., x - 1/100, whose root 1/100 lies well inside 1/(1 + S)).

   The side of zero holding the root is decided with a closed-interval
   Sturm count on [l, -e], so a root sitting exactly on l is found. *)

fun separate_from_zero p l u =
    let val a0 = rat_abs (p_lc (p_eval (p, 0, Rat.zero)))
        val e = Rat.mult (a0, Rat.inv (Rat.add (a0, Algebra.p_sum_abs_coeffs p)))
        val e' = Rat.neg e
    in
        if Rat.le (l, e') andalso Sturm.num_roots_in_cl_intvl p (l, e') > 0 then
            (* negative root in [l, e'] *)
            (p, l, e') : real_alg
        else
            (* positive root in [e, u]; when l > e' the root cannot be
               negative, since (e', e) is root-free and [l, u] holds a
               non-zero root *)
            (p, e, u)
    end;

(* Given a poly p in Q[x], and l, u in Q, construct a real algebraic
   number representation of the form (p, l, u). We normalize the
   representation to ensure (i) p is square-free and monic, (ii) p
   truly has exactly one real root in [l, u], and (iii) that [l, u]
   contains zero iff (p, l, u) represents 0.  In this final case, we
   return RealAlg.zero, so that a constructed real_alg value r
   represents zero iff r = RealAlg.zero.  When the interval excludes
   zero and `x' divides p, the (single, as p is square-free) factor
   `x' is removed from p.

   Raises Non_unique if p does not have exactly one root in [l, u]. *)

fun mk_real_alg p l u =
    let val p = Algebra.p_monic (Sturm.square_free p)
        val k = Sturm.num_roots_in_cl_intvl p (l, u)
    in
        if k <> 1 then
            raise Non_unique ("Poly (" ^ (p_toString p) ^ ") " ^
                              "does not have exactly one root in " ^
                              "the interval [" ^ (Rat.toString l) ^ ", " ^
                              (Rat.toString u) ^ "].")
        else
            let val (fs, m) = Groebner.p_div p [p_x]
                val x_divides_p = (m = p_zero)
                val contains_zero =
                    Rat.le (l, Rat.zero) andalso Rat.le (Rat.zero, u)
            in
                if contains_zero then
                    (* The unique root in [l, u] is 0 iff x divides p. *)
                    (if x_divides_p then zero else separate_from_zero p l u)
                else if x_divides_p then
                    (* Not representing 0: drop the factor x. *)
                    (Array.sub (fs, 0), l, u)
                else
                    (p, l, u)
            end
    end;

(* Compute an interval representation of a real algebraic number.
   Currently, this is just the identity function. But, as we support
   other notations for real_alg's, this will get more elaborate, e.g.,
   converting between positional and interval representation. *)

fun intvl_rep (r : real_alg) = r;

(* Given two real algebraic numbers alpha, beta, a polynomial f, and
   an arithmetic operation on intervals intvl_op, continue refining
   the intervals for alpha and beta (via Sturm.refine_root) until the
   square-free part of f has precisely one root in
   intvl_op(intvl(alpha), intvl(beta)); return that interval.  The
   square-free part is computed once, outside the refinement loop. *)

fun refine_with_op (alpha, beta : real_alg, f : poly, intvl_op) =
    let val f = Sturm.square_free f
        fun loop ((p1, l1, u1), (p2, l2, u2)) =
            let val cur_intvl = intvl_op ((l1, u1), (l2, u2))
            in
                if Sturm.num_roots_in_cl_intvl f cur_intvl = 1 then
                    cur_intvl
                else
                    let val (l1', u1') = Sturm.refine_root p1 (l1, u1)
                        val (l2', u2') = Sturm.refine_root p2 (l2, u2)
                    in
                        loop ((p1, l1', u1'), (p2, l2', u2'))
                    end
            end
    in
        loop (alpha, beta)
    end;

(* Given an interval representation for a real algebraic number alpha,
   compute a representation for -alpha: the root of p(-x) in [-u, -l]. *)

fun neg alpha =
    if alpha = zero then zero
    else
        let val (p, l, u) = alpha
        in
            mk_real_alg (Algebra.p_subst_neg_x p) (Rat.neg u) (Rat.neg l)
        end;

(* Given interval representations for two real algebraic numbers alpha
   and beta, compute an interval representation for (alpha + beta). *)

fun add (alpha, beta) =
    if alpha = zero then beta
    else if beta = zero then alpha
    else
        let val (p1, _, _) = alpha
            val (p2, _, _) = beta
            val x_minus_y = [(Rat.one, [(0, 1)]),
                             (Rat.neg Rat.one, [(1, 1)])] : poly
            val y = [(Rat.one, [(1, 1)])] : poly
            val p1_x_minus_y = Algebra.p_subst (p1, 0, x_minus_y)
            val p2_y = Algebra.p_subst (p2, 0, y)
            val res = Resultant.biv_resultant p1_x_minus_y p2_y
            (* Note: res in Q[y]. We must convert to Q[x]. *)
            val f = mk_univ_in (res, 0)
            (* Now, f has (alpha + beta) as a root. But, we must still
               compute an interval [l, u] s.t. (alpha + beta) is f's only
               root in [l, u]. *)
            val (l, u) =
                refine_with_op (alpha, beta, f,
                                (fn ((l1, u1), (l2, u2)) =>
                                    (Rat.add (l1, l2), Rat.add (u1, u2))))
        in
            mk_real_alg f l u
        end;

(* TODO: multiplication (alpha times beta) of real algebraic numbers is
   not implemented yet. *)

(* Given a real algebraic number alpha and a polynomial g(x), compute
   the sign of g(alpha).

   For a non-degenerate interval this is a Sturm (Tarski) query: with
   the generalized Sturm sequence of p and p'g, V(l) - V(u) equals
   sign(g(alpha)) because alpha is p's only root in the interval.  That
   query counts roots in the half-open interval (l, u], which is empty
   when l = u, so degenerate intervals [l, l] (including
   RealAlg.zero = (x, 0, 0)) are handled by evaluating g at l
   directly. *)

fun sign_at g alpha =
    let val (p, l, u) = alpha
    in
        if Rat.compare (l, u) = EQUAL then
            let val g_l = p_eval (g, 0, l)
            in
                if g_l = p_zero then Sturm.ZERO
                else
                    let val c = p_lc g_l
                    in
                        if Rat.gt0 c then Sturm.POS
                        else if Rat.compare (c, Rat.zero) = EQUAL then Sturm.ZERO
                        else Sturm.NEG
                    end
            end
        else
            let val sc = Sturm.gen_sturm_chain p (Algebra.p_mult (Sturm.d_dx p, g))
                fun changes_at q =
                    Sturm.num_sign_changes
                        (map (fn h => Sturm.sign_of_rat (p_lc (p_eval (h, 0, q)))) sc)
                val n = changes_at l - changes_at u
            in
                if n < 0 then Sturm.NEG
                else if n = 0 then Sturm.ZERO
                else Sturm.POS
            end
    end;

(* Given a real algebraic number alpha, compute its sign. By virtue of
   the fact that our intervals only contain zero precisely when the
   real algebraic number is RealAlg.zero, we can read the sign
   directly from one side of the interval boundary. *)

fun sign alpha =
    if alpha = zero then Sturm.ZERO
    else
        let val (_, l, _) = alpha in
            if Rat.gt0 l then Sturm.POS
            else Sturm.NEG
        end;

(* Given a list of polynomials, compute all of their roots and
   represent them as real algebraic numbers. Note, we do not attempt
   to make this list minimal: it may contain the same real algebraic
   number more than once and represented in more than one way.  The
   roots of later polynomials in ps appear before those of earlier
   ones. *)

fun roots_of_polys ps =
    foldl (fn (p, acc) =>
              let val p' = Sturm.square_free p
              in
                  map (fn (l, u) => mk_real_alg p' l u) (Sturm.isolate_roots p')
                  @ acc
              end)
          [] ps;

(* Midpoint of two rationals a, b *)

fun rat_mid a b =
    Rat.mult (Rat.add (a, b), Rat.inv (Rat.rat_of_int 2));

(* Given a rational, construct a real_alg representing it.  Note, we
   need to add a little wiggle room for the interval!  The rational 0
   maps to the canonical RealAlg.zero (building x - 0 would leave a
   zero-coefficient monomial in the polynomial). *)

fun real_alg_of_rat q =
    if Rat.compare (q, Rat.zero) = EQUAL then zero
    else
        let val p = [(Rat.one, [(0, 1)]), (Rat.neg q, [])] : poly
        in
            mk_real_alg p (Rat.add (Rat.neg Rat.one, q)) (Rat.add (Rat.one, q))
        end;

(* Construct sample points for a CAD of R^1 induced by polynomials ps.
   The result lists, in ascending order: a rational sample below the
   least root, then each real root of the product of ps followed by a
   rational sample of the open sector to its right (the midpoint
   between neighbouring isolating intervals, or one past the greatest
   root's upper bound).  If the product has no real roots, R^1 is a
   single sector and [RealAlg.zero] is returned.  Only the first
   occurrence of RealAlg.zero is kept. *)

fun univ_cad_sample ps =
    let val prod_ps = foldl Algebra.p_mult p_one ps
        val prod_ps' = Algebra.p_monic (Sturm.square_free prod_ps)
        (* Isolating intervals are disjoint, so comparing the upper end
           of one with the lower end of another orders the roots. *)
        fun cmp ((_, _, u1), (_, l2, _)) = Rat.compare (u1, l2)
        val roots =
            Useful.sort cmp
                        (map (fn (l, u) => mk_real_alg prod_ps' l u)
                             (Sturm.isolate_roots prod_ps'))
        (* Interleave each root with a sample of the sector to its right. *)
        fun with_sectors [] = []
          | with_sectors [alpha as (_, _, u)] =
            [alpha, real_alg_of_rat (Rat.add (u, Rat.one))]
          | with_sectors ((alpha as (_, _, u1)) :: (rest as ((_, l2, _) :: _))) =
            alpha :: real_alg_of_rat (rat_mid u1 l2) :: with_sectors rest
        (* Keep only the first occurrence of RealAlg.zero. *)
        fun drop_extra_zeros [] _ = []
          | drop_extra_zeros (alpha :: rst) seen_zero =
            if alpha = zero then
                (if seen_zero then drop_extra_zeros rst true
                 else alpha :: drop_extra_zeros rst true)
            else
                alpha :: drop_extra_zeros rst seen_zero
    in
        case roots of
            [] => [zero]
          | (_, l, _) :: _ =>
            drop_extra_zeros
                (real_alg_of_rat (Rat.add (l, Rat.neg Rat.one)) :: with_sectors roots)
                false
    end;

(* String representation of a real algebraic number *)

fun toString (r : real_alg) =
    case r of
        (p, l, u) => ("RealRoot(" ^ (Algebra.p_toString p) ^
                      ", [" ^ (Rat.toString l) ^ ", " ^
                      (Rat.toString u) ^ "])");

(* Examples:

NOTE(review): these transcripts were recorded before intervals were
normalized to exclude zero, so some printed intervals and results
(e.g., for alpha, beta and gamma) may differ from the current code.

val p = [(Rat.one, [(0, 2)]), (Rat.rat_of_int ~2, [])] : poly;
p_toString p;
val it = "x0^2 + -2": string

val alpha = mk_real_alg p Rat.zero (Rat.rat_of_int 5);
val alpha =
  ([(Rat (true, 1, 1), [(0, 2)]), (Rat (false, 2, 1), [])],
   Rat (true, 0, 1), Rat (true, 5, 1)): real_alg

toString(alpha);
val it = "RealRoot(x0^2 + -2, [0,5])": string

val beta = mk_real_alg p Rat.zero Rat.zero;
Exception-
   Non_unique
  "Poly (x0^2 + -2) does not have exactly one root in the interval [0,0]."
   raised

val beta = mk_real_alg p (Rat.rat_of_int ~2) Rat.zero;
val beta =
   ([(Rat (true, 1, 1), [(0, 2)]), (Rat (false, 2, 1), [])],
    Rat (false, 2, 1), Rat (true, 0, 1)): real_alg

toString(beta);
val it = "RealRoot(x0^2 + -2, [-2,0])": string

val gamma = add(alpha, beta);
val gamma =
   ([(Rat (true, 1, 1), [(0, 3)]), (Rat (false, 8, 1), [(0, 1)])],
    Rat (false, 2, 1), Rat (true, 3, 2)): real_alg

toString(gamma);
val it = "RealRoot(x0^3 + -8 x0, [-2,3/2])": string

val q = [(Rat.one, [(0, 2)]), (Rat.rat_of_int ~3, [])] : poly;
p_toString q;
val it = "x0^2 + -3": string

val sigma = mk_real_alg q (Rat.one) (Rat.rat_of_int 10);
toString sigma;
val it = "RealRoot(x0^2 + -3, [1,10])": string

val gamma = add(alpha, sigma);

val r = [(Rat.one, [(0, 1)])] : poly;
p_toString r;

val beta = real_alg_of_rat (Rat.neg (Rat.rat_of_quotient (~19, 20)));
sign_at p beta;

*)

end