open OcamlTypes               (* Coq extraction *)  (* for ast_to_ocmType *) 
open Specif                   (* Coq extraction *)  (* for Coq_existT *) 
open Semantics                (* Coq extraction *)
open DecSetoid                (* Coq extraction *)
open Bisemigroup              (* Coq extraction *)
open BisemigroupPropRecord    (* Coq extraction *)
open SemigroupProperties      (* Coq extraction *)
open SemigroupPropRecord      (* Coq extraction *)

open Errors 
open CoqInterface
open MreAST
open Parsers

(* Path algorithms that [runPathAlgorithm] can dispatch to. *)
type algorithmType =
  | MatrixMultiplication  (* all pairs, by iterated matrix multiplication *)
  | Dijkstra              (* single source, from the given start node *)
  | BellmanFord           (* not implemented yet: always raises *)

(*---------------------------------------------------------------------------*)
(*                              Graph algorithm                              *)
(*---------------------------------------------------------------------------*)
(* [trim s] returns [s] without its leading and trailing blanks, where a
   blank is a space, tab, newline or carriage return.  Unlike
   [String.trim], form feeds are kept.  The bounds are located first and a
   single substring is taken, so this runs in linear time instead of
   re-allocating the string once per removed character.  When nothing needs
   to be removed, [s] itself is returned. *)
let trim s =
  let is_blank c = c = ' ' || c = '\t' || c = '\n' || c = '\r' in
  let len = String.length s in
  let first = ref 0 in
  while !first < len && is_blank s.[!first] do incr first done;
  let last = ref (len - 1) in
  while !last >= !first && is_blank s.[!last] do decr last done;
  if !first = 0 && !last = len - 1 then s
  else String.sub s !first (!last - !first + 1)

(* In-memory graph, as built by [readInputGraph]. *)
type graph = {
   vnum  : int;                      (* number of vertices; used as the size of [adj] *)
   adj   : (int * wfAst) list array  (* adj.(s): edges leaving s, as (target, label term) pairs *)
}

(*
 * Input file format:
 *   1:  [algebra specification]
 *   2:  [number of vertices n in the graph]
 *   3:  [source of edge_1] : [target of edge_1] : [label of edge_1]
 * ...
 * k+2:  [source of edge_k] : [target of edge_k] : [label of edge_k]
 *
 * Vertices are numbered 0 .. n-1.  Blank edge lines are ignored.
 *)

(* [readInputGraph graphFile] reads a graph in the format described above.
   Returns [(term, tp, g)] where [term] is the parsed algebra
   specification, [tp] is [OcamlTypes.otLang term], and [g] has [n]
   vertices whose adjacency lists hold [(target, label)] pairs.

   Malformed edge lines are reported on stdout and skipped.  This covers
   lines with fewer than three ':'-separated fields, non-numeric endpoints,
   endpoints outside [0, n), and labels for which [parseCarrier] raises
   [Failure].  Blank lines are ignored.  Errors while reading the first two
   lines propagate to the caller.  The input channel is closed in every
   case, including when an exception escapes. *)
let readInputGraph graphFile =
  let f = open_in graphFile in
  let sep = Str.regexp_string ":" in
  let report ln = Printf.printf "Graph input error. Line: '%s'.\n" ln in
  let parse () =
    let term = parseTerm (input_line f) in  (* algebra specification *)
    let tp = OcamlTypes.otLang term in
    (* trimmed so that trailing blanks / CRLF line endings are accepted *)
    let n = int_of_string (trim (input_line f)) in  (* number of vertices *)
    let a = Array.make n [] in
    let in_range v = v >= 0 && v < n in
    (try
       while true do
         let ln = trim (input_line f) in
         if ln <> "" then begin
           match List.map trim (Str.split sep ln) with
           | src :: dst :: label :: _ ->
             (try
                let s = int_of_string src in
                let t = int_of_string dst in
                (* reject bad endpoints here instead of crashing on a.(s)
                   now, or on the target later inside the algorithms *)
                if not (in_range s && in_range t) then
                  failwith "vertex out of range";
                let e = parseCarrier tp label in
                a.(s) <- (t, e) :: a.(s)
              with Failure _ -> report ln)
           | _ -> report ln
         end
       done
     with End_of_file -> ());
    (term, tp, { vnum = n; adj = a })
  in
  let result =
    try parse ()
    with ex -> close_in_noerr f; raise ex
  in
  close_in f;
  result


(*---------------------------------------------------------------------------*)
(*                      Matrix multiplication algorithm                      *)
(*---------------------------------------------------------------------------*)

(* [runMatrixMultiplication sem g] solves the all-pairs path problem on [g]
   by repeated matrix multiplication.

   The initial matrix is M = I (+) A, where:
   - I has the times identity (alpha) on the diagonal and the plus identity
     (omega) elsewhere;
   - A holds the edge labels, with parallel edges and self-loops combined
     with plus.
   The powers M^2, M^3, ... are computed until two consecutive powers are
   equal under the carrier's setoid equality.  Each intermediate matrix is
   printed, followed by the final one.  NOTE(review): the loop only
   terminates when the powers stabilise for the given algebra and graph.

   Raises [NoSemantics] if [sem] is a semantic error, [Failure] if [sem]
   is not a bisemigroup, and [AlgorithmsRequirementsFailed] if times or plus
   has no identity element. *)
let runMatrixMultiplication sem g =
  (* get semantics *)
  let bs, props, tp =
    match sem with
      | BsSem (Coq_existT (s, p), tp) -> s, p, tp
      | SemErr y -> raise (NoSemantics (string_of_lang y.exp, getString y.errMsg))
      | _ -> raise (Failure "We need a bisemigroup!")
  in
  (* alpha: identity of times; omega: identity of plus *)
  let alpha =
     match props.bs_times_sgprop.hasIdentity with
       | Some (Coq_existT (a, _)) -> a
       | None -> raise (AlgorithmsRequirementsFailed "The specified algebra does not have an identity element for times.")
  in
  let omega =
     match props.bs_plus_sgprop.hasIdentity with
       | Some (Coq_existT (a, _)) -> a
       | None -> raise (AlgorithmsRequirementsFailed "The specified algebra does not have an identity element for plus.")
  in
  let n = g.vnum in
  (* carrier operations *)
  let c_eq = bs.setoid.equal in
  let c_plus = bs.plus in
  let c_times = bs.times in
  (* initial matrix M = I (+) A.  Edges are added with plus, so a self-loop
     no longer erases the diagonal identity and a parallel edge no longer
     silently replaces another one. *)
  let m = Array.make_matrix n n omega in
  for i = 0 to n - 1 do
    m.(i).(i) <- alpha;
    List.iter (fun (j, e) -> m.(i).(j) <- c_plus m.(i).(j) (fromAst tp e)) g.adj.(i)
  done;
  (* entrywise setoid equality; stops at the first differing entry *)
  let m_eq a b =
    let rec go i j =
      if i >= n then true
      else if j >= n then go (i + 1) 0
      else c_eq a.(i).(j) b.(i).(j) && go i (j + 1)
    in
    go 0 0
  in
  (* c(i,j) = sum over k of a(i,k) (x) b(k,j) *)
  let m_times a b =
    let c = Array.make_matrix n n omega in
    for i = 0 to n - 1 do
      for j = 0 to n - 1 do
        let acc = ref (c_times a.(i).(0) b.(0).(j)) in
        for k = 1 to n - 1 do
          acc := c_plus !acc (c_times a.(i).(k) b.(k).(j))
        done;
        c.(i).(j) <- !acc
      done
    done;
    c
  in
  let print_matrix m =
    for i = 0 to n - 1 do
      for j = 0 to n - 1 do
         Printf.printf "%3d --> %3d : %s\n" i j (string_of_carrier tp m.(i).(j))
      done
    done
  in
  (* run the algorithm until M^k = M^(k+1) *)
  let last = ref m in
  let current = ref (m_times !last m) in
  while not (m_eq !last !current) do
     Printf.printf "Current matrix:\n";
     print_matrix !current;
     let next = m_times !current m in
     last := !current;
     current := next
  done;
  Printf.printf "\nThe solution computed by matrix multiplication algorithm:\n";
  print_matrix !current

(*---------------------------------------------------------------------------*)
(*                           Dijkstra's algorithm                            *)
(*---------------------------------------------------------------------------*)

(* [runDijkstra startNode sem g] computes, with Dijkstra's algorithm, the
   best path weight from [startNode] to every vertex of [g], and prints one
   line per vertex.

   Requirements:
   - [sem] is a bisemigroup whose plus is selective;
   - the algebra is left increasing;
   - times has an identity (alpha, the weight at the start node) and an
     annihilator (omega, the initial weight of every other vertex).

   Relaxing the edge (u, v) with label [e] uses the candidate weight
   [times e dist(u)].

   Raises [NoSemantics] or [Failure] when [sem] is unusable,
   [AlgorithmsRequirementsFailed] when a requirement is missing, and
   [Invalid_argument] when [startNode] is not a vertex of [g]. *)
let runDijkstra startNode sem g =
  (* get semantics *)
  let bs, props, tp =
    match sem with
      | BsSem (Coq_existT (s, p), tp) -> s, p, tp
      | SemErr y -> raise (NoSemantics (string_of_lang y.exp, getString y.errMsg))
      | _ -> raise (Failure "We need a bisemigroup!")
  in
  (* check selectivity *)
  (match props.bs_plus_sgprop.isSelective with
     | Some _ -> ()
     | None -> raise (AlgorithmsRequirementsFailed "The specified algebra is not selective."));
  (* check increasingness *)
  (match props.leftIncreasing with
     | Some _ -> ()
     | None -> raise (AlgorithmsRequirementsFailed "The specified algebra is not left increasing."));
  (* times identity and annihilator *)
  let alpha =
     match props.bs_times_sgprop.hasIdentity with
       | Some (Coq_existT (a, _)) -> a
       | None -> raise (AlgorithmsRequirementsFailed "The specified algebra does not have an identity element for times.")
  in
  let omega =
     match props.bs_times_sgprop.hasAnnihilator with
       | Some (Coq_existT (a, _)) -> a
       | None -> raise (AlgorithmsRequirementsFailed "The specified algebra does not have an annihilator for times.")
  in
  if startNode < 0 || startNode >= g.vnum then
    invalid_arg
      (Printf.sprintf "runDijkstra: start node %d is not a vertex (the graph has %d vertices)"
         startNode g.vnum);
  (* adjacency lists with labels converted to carrier values *)
  let adj = Array.map (List.map (fun (x, e) -> (x, fromAst tp e))) g.adj in
  (* carrier operations *)
  let c_eq = bs.setoid.equal in
  let c_plus = bs.plus in
  let c_times = bs.times in
  (* strict order induced by plus: x < y iff x (+) y = x and y (+) x <> y *)
  let c_lt x y = (c_eq (c_plus x y) x) && not (c_eq (c_plus y x) y) in
  (* Dijkstra's algorithm *)
  let dist = Array.make g.vnum omega in
  dist.(startNode) <- alpha;
  (* unvisited vertices, highest index first; this order decides ties in
     the minimum search below.  Built tail-recursively. *)
  let rec build acc k = if k >= g.vnum then acc else build (k :: acc) (k + 1) in
  let q = ref (build [] 0) in
  while !q <> [] do
    (* closest unvisited vertex *)
    let u =
      List.fold_left
        (fun best v -> if c_lt dist.(best) dist.(v) then best else v)
        (List.hd !q) (List.tl !q)
    in
    if c_eq dist.(u) omega then
      q := []  (* the closest remaining vertex is unreachable: stop *)
    else begin
      (* remove u *)
      q := List.filter (fun v -> v <> u) !q;
      (* relax neighbours *)
      List.iter
        (fun (v, e) ->
           let alt = c_times e dist.(u) in
           if c_lt alt dist.(v) then dist.(v) <- alt)
        adj.(u)
    end
  done;
  (* print distances *)
  Printf.printf "\nThe solution computed by Dijkstra's algorithm:\n";
  for i = 0 to g.vnum - 1 do
     Printf.printf "%3d --> %3d : %s\n" startNode i (string_of_carrier tp dist.(i))
  done

  
(*---------------------------------------------------------------------------*)
(*                         Bellman-Ford algorithm                            *)
(*---------------------------------------------------------------------------*)

(* Placeholder for the Bellman-Ford algorithm.  It takes the same
   (semantics, graph) arguments as the matrix-multiplication solver but
   ignores them and always raises [InternalError]. *)
let runBellmanFord _sem _g =
  let msg = "Bellman-Ford algorithm not yet implemented" in
  raise (InternalError msg)

(* [runPathAlgorithm algorithm startNode graphFile] loads the graph stored
   in [graphFile] and computes the semantics of its algebra specification
   with [getSem].  It then hands both to the selected algorithm.
   [startNode] is only used by Dijkstra's algorithm. *)
let runPathAlgorithm algorithm startNode graphFile =
  let spec, _, graph = readInputGraph graphFile in
  let semantics = getSem spec in
  match algorithm with
  | Dijkstra -> runDijkstra startNode semantics graph
  | MatrixMultiplication -> runMatrixMultiplication semantics graph
  | BellmanFord -> runBellmanFord semantics graph

