open Camlp4.PreCast
open Term
open Combinators 

module Monad =
struct
  (* Error monad: a computation either fails with a source location and a
     message, or succeeds with a value. *)
  type 'a t = Error of Loc.t * string | Ok of 'a

  let return x = Ok x

  (* Sequencing: feed the value of [m] to [f]; a failure short-circuits and
     is propagated unchanged. *)
  let (>>=) m f =
    match m with
    | Ok v -> f v
    | Error (loc, msg) -> Error (loc, msg)

  (* Fail with [msg] at source location [loc]. *)
  let error_loc loc msg = Error (loc, msg)

  (* Fail with [msg] at the dummy location [Loc.ghost]. *)
  let error msg = error_loc Loc.ghost msg

  (* Re-locate a failure to [loc]; a success passes through unchanged. *)
  let error_at loc m =
    match m with
    | Ok v -> Ok v
    | Error (_, msg) -> Error (loc, msg)
end
open Monad

(* Typing context: one entry per bound variable, giving its name, its type,
   and the time step at which it was bound. New bindings are consed onto the
   front, so the head of the list is the most recently bound variable. *)
type env = (Term.var * Term.tp * int) list

(* Debugging aid: print the variable names bound in [env] to stdout,
   in list order, e.g. "vars = [ x y ]". *)
let printvars env =
  let print_name (name, _, _) = Printf.printf "%s " name in
  Printf.printf "vars = [ ";
  List.iter print_name env;
  Printf.printf "]\n"

(* To look up variables we need to construct the n-th projection *)

(* [lookup _loc env x j] finds the innermost binding of [x] in [env] and
   returns the combinator that projects it out of the context, together with
   its type. The head entry is selected with [pi2]; every entry skipped on the
   way adds a [pi1]. A variable bound at time [i] may only be used at a time
   [j >= i]; the gap is bridged by [delay (j - i) tp]. *)
let rec lookup _loc env x j =
  match env with
  | [] -> error_loc _loc (Printf.sprintf "Unbound variable '%s'" x)
  | (y, _, _) :: rest when y <> x ->
      lookup _loc rest x j >>= (fun (proj, tp) ->
        return (compose pi1 proj, tp))
  | (_, tp, i) :: _ ->
      if i > j then
        error_loc _loc
          (Printf.sprintf "Variable '%s' at time %d, used at time %d" x i j)
      else
        return (compose pi2 (delay (j - i) tp), tp)

(* Printing types and reporting type mismatches *)


(* Render a type as a fully parenthesized string, e.g. the type of streams of
   pairs of units prints as "S((one * one))". *)
let rec string_of_type (InT (_, shape)) =
  let s = string_of_type in
  match shape with
  | One -> "one"
  | Discrete -> "D(_)"
  | Prod (l, r) -> Printf.sprintf "(%s * %s)" (s l) (s r)
  | Arrow (l, r) -> Printf.sprintf "(%s -> %s)" (s l) (s r)
  | Sum (l, r) -> Printf.sprintf "(%s + %s)" (s l) (s r)
  | Gui a -> Printf.sprintf "Gui(%s)" (s a)
  | Stream a -> Printf.sprintf "S(%s)" (s a)
  | Next a -> Printf.sprintf "Next(%s)" (s a)

(* [mismatch loc what expected actual] fails at [loc] with a message that
   names both the expected and the actual type. *)
let mismatch loc what (t1 : Term.tp) (t2 : Term.tp) =
  let expected = string_of_type t1 in
  let actual = string_of_type t2 in
  error_loc loc (Printf.sprintf "%s: expected %s but got %s" what expected actual)

(* [mismatch1 loc msg actual] fails at [loc] with [msg] followed by the
   offending type. *)
let mismatch1 loc msg (t1 : Term.tp) =
  let actual = string_of_type t1 in
  error_loc loc (Printf.sprintf "%s, but got %s" msg actual)


(* What follows implements a simple bidirectional typechecking
   algorithm ([check] and [synth]); along the way it produces a
   (well-typed) ML syntax tree.
*)

(* [shrinkenv loc fvs env] restricts [env] to the variables in [fvs].
   It returns the smaller context together with a weakening combinator from
   the full context to it. A kept entry contributes [map_prod wk id]; a
   dropped entry contributes [compose pi1 wk]. Only the first (innermost)
   occurrence of a name is kept, because the name is removed from [fvs] once
   found. Fails if [fvs] mentions a variable that is not in [env]. *)
let rec shrinkenv loc fvs = function
  | [] when Vars.is_empty fvs -> return ([], one)
  | [] -> error_loc loc (Printf.sprintf "Freevars larger than context?")
  | ((x, _, _) as entry) :: rest when Vars.mem x fvs ->
      shrinkenv loc (Vars.remove x fvs) rest >>= (fun (kept, wk) ->
        return (entry :: kept, map_prod wk id))
  | _ :: rest ->
      shrinkenv loc fvs rest >>= (fun (kept, wk) ->
        return (kept, compose pi1 wk))

(* [(p, tp, j) ++ env] extends [env] with the variables bound by pattern [p]
   matched against type [tp] at time [j]. For a pair pattern, the first
   component is bound before the second, so the second ends up at the head.
   Raises a located [Failure] if the pattern's shape does not fit [tp]. *)
let rec (++) (p, (InT (_loc, tp) as tp0), j) env =
  let reject msg = Loc.raise _loc (Failure msg) in
  match p, tp with
  | PVar x, _ -> (x, tp0, j) :: env
  | PUnit, One -> env
  | PPair (p1, p2), Prod (tp1, tp2) ->
      let with_first = (p1, tp1, j) ++ env in
      (p2, tp2, j) ++ with_first
  | PUnit, _ ->
      reject (Printf.sprintf "Unit pattern, but got %s\n" (string_of_type tp0))
  | PPair (_, _), _ ->
      reject (Printf.sprintf "Pair pattern, but got %s\n" (string_of_type tp0))
    
(* [env_type loc env] is the type of context [env] as a left-nested product
   ending in [one]. The head entry is the right component of the outermost
   [Prod]. Each entry's type is wrapped in as many [Next] layers as the time
   at which it was bound. *)
let rec env_type loc = function
  | [] -> InT (loc, One)
  | (_, tp, j) :: rest ->
      (* [wrap n] is [tp] under [n] layers of [Next]. *)
      let rec wrap n = if n = 0 then tp else InT (loc, Next (wrap (n - 1))) in
      let entry_tp = wrap j in
      InT (loc, Prod (env_type loc rest, entry_tp))
    
(* Bidirectional typechecker / elaborator.

   [check env e tp j] checks term [e] against type [tp] at time [j] in
   context [env]. On success it returns the elaborated combinator term.

   [synth env e j] infers a type for [e] at time [j]. It returns the
   elaborated term together with the inferred type.

   The [j]-indexed combinators ([delay j], [next_prod j], [next_exp j],
   [map_next j]) presumably account for the value living under [j] layers of
   [Next]. Primitive operations are therefore lifted with [map_next j]
   throughout. *)
let rec check env (In (_loc, e) as e0) (InT (_, tp) as tp0) j : ast t =
  match e, tp with
  | Unit, One -> return (compose one (delay j tp0))
  | Unit, _ -> mismatch1 _loc "Expected unit type" tp0
  | Pair (e1, e2), Prod (tp1, tp2) ->
      check env e1 tp1 j >>= (fun u1 ->
      check env e2 tp2 j >>= (fun u2 ->
      return (compose (pair u1 u2) (from (next_prod j)))))
  | Pair (_, _), _ -> mismatch1 _loc "Expected product type" tp0
  | Lam (p, e2), Arrow (tp1, tp2) ->
      (* Close over only the lambda's free variables: shrink the context
         first, then curry the body over the shrunk context. *)
      shrinkenv _loc (freevars e0) env >>= (fun (env, wk) ->
      check ((p, tp1, j) ++ env) e2 tp2 j >>= (fun u2 ->
      return
        (compose wk
           (compose
              (curry (env_type _loc env) (compose (pattern p j) u2))
              (from (next_exp j))))))
  | Lam (_, _), _ -> mismatch1 _loc "Expected function type" tp0
  | Inject u, _ ->
      if j = 0 then return (compose one (fun _ _ -> u))
      else error_loc _loc "Can only inject constants at time 0"
  | Inl _, _ | Inr _, _ | Case (_, _, _, _, _), _ ->
      error_loc _loc "Sums not yet implemented"
  | Cons (e1, e2), Stream tp1 ->
      (* The head lives at time [j], the tail one step later. *)
      check env e1 tp1 j >>= (fun u1 ->
      check env e2 tp0 (j + 1) >>= (fun u2 ->
      return (compose (compose (pair u1 u2) (from (next_prod j))) (map_next j cons))))
  | Cons (_, _), _ -> mismatch1 _loc "Expected stream type" tp0
  | Fix (x, e1), Stream tp1 ->
      (* The recursive variable is only available one step later. *)
      check ((x, tp0, j + 1) :: env) e1 tp0 j >>= (fun u1 ->
      return (fix_next (env_type _loc env) j (fix_stream tp1) u1))
  | Fix (x, e1), Arrow (_, _) ->
      check ((x, tp0, j + 1) :: env) e1 tp0 j >>= (fun u1 ->
      return (fix_next (env_type _loc env) j fix_fun u1))
  | Fix (x, e1), Gui (InT (_, Stream _) as tp1) ->
      (* NOTE(review): here the recursive variable is bound at the inner
         stream type [tp1], not at [tp0] = Gui(...); presumably what
         [fix_guistream] expects -- confirm. *)
      check ((x, tp1, j + 1) :: env) e1 tp0 j >>= (fun u1 ->
      return (fix_next (env_type _loc env) j (fix_guistream tp1) u1))
  | Fix (_, _), _ ->
      (* [synth] has no case for fixed points, so report the type here
         instead of falling through to a generic synthesis error. *)
      mismatch1 _loc "Expected stream, function or GUI stream type for fixpoint" tp0
  | NextE e1, Next tp1 -> check env e1 tp1 (j + 1)
  | NextE _, _ -> mismatch1 _loc "Expected delay type" tp0
  | Let (p, e1, body), _ ->
      synth env e1 j >>= (fun (u1, tp1) ->
      check ((p, tp1, j) ++ env) body tp0 j >>= (fun u0 ->
      return (compose (pair id u1) (compose (pattern p j) u0))))
  | Zip (e1, e2), Stream (InT (_, Prod (tp1_elt, tp2_elt))) ->
      check env e1 (InT (_loc, Stream tp1_elt)) j >>= (fun u1 ->
      check env e2 (InT (_loc, Stream tp2_elt)) j >>= (fun u2 ->
      (* [into zip] acts on streams at time 0, so it must be lifted under the
         [j] delays with [map_next j], exactly like [cons] above. *)
      return (compose (compose (pair u1 u2) (from (next_prod j))) (map_next j (into zip)))))
  | Zip (_, _), _ ->
      (* [synth] has no case for zips, so report the type here. *)
      mismatch1 _loc "Expected stream of pairs type" tp0
  | GuiReturn e1, Gui tp1 ->
      check env e1 tp1 j >>= (fun u1 ->
      return (compose u1 (map_next j Combinators.return)))
  | GuiReturn _, _ -> mismatch1 _loc "Expected Gui type" tp0
  | LetGui (p, e1, e2), Gui _ ->
      synth env e1 j >>= (fun (u1, InT (tploc, tp1)) ->
        match tp1 with
        | Gui tp1' ->
            check ((p, tp1', j) ++ env) e2 tp0 j >>= (fun u2 ->
            return
              (compose (pair id u1)
                 (compose (from (next_prod j))
                    (map_next j (compose strength (bind (compose (pattern p j) u2)))))))
        | _ -> mismatch1 _loc "Expected GUI type in binder" (InT (tploc, tp1)))
  | LetGui (_, _, _), _ -> mismatch1 _loc "Expected GUI type in body" tp0
  | _, _ ->
      (* Mode switch: synthesize a type and compare it with the expected one. *)
      synth env e0 j >>= (fun (u, tp0') ->
        if tp_equal tp0 tp0' then return u
        else mismatch _loc "check" tp0 tp0')

and synth env (In (_loc, e)) j =
  match e with
  | Var x -> lookup _loc env x j
  | Embed u ->
      let tp = InT (_loc, Discrete) in
      return (compose (embed u) (delay j tp), tp)
  | Fst e1 ->
      synth env e1 j >>= (fun (u1, (InT (_, tp) as tp0)) ->
        match tp with
        | Prod (tp_fst, _) -> return (compose u1 (map_next j pi1), tp_fst)
        | Discrete ->
            return (compose u1 (map_next j (compose (from pair_discrete) pi1)), tp0)
        | _ -> mismatch1 _loc "fst: expected product type" tp0)
  | Snd e1 ->
      synth env e1 j >>= (fun (u1, (InT (_, tp) as tp0)) ->
        match tp with
        | Prod (_, tp_snd) -> return (compose u1 (map_next j pi2), tp_snd)
        | Discrete ->
            return (compose u1 (map_next j (compose (from pair_discrete) pi2)), tp0)
        | _ -> mismatch1 _loc "snd: expected product type" tp0)
  | Unfold e1 ->
      synth env e1 j >>= (fun (u1, (InT (tploc, tp1) as tp0)) ->
        match tp1 with
        | Arrow (tp1, InT (_, Prod (tp2, InT (_, Next tp1')))) ->
            if tp_equal tp1 tp1' then begin
              let tp = InT (tploc, Arrow (tp1, InT (tploc, Stream tp2))) in
              let u = compose u1 (map_next j unfold) in
              return (u, tp)
            end
            else mismatch _loc "unfold -- impossible mismatch" tp1 tp1'
        | _ -> mismatch1 _loc "unfold -- expected function" tp0)
  | App (e1, e2) ->
      synth env e1 j >>= (fun (u1, (InT (tploc, tp) as tp1)) ->
        match tp with
        | Arrow (tp2, tp3) ->
            check env e2 tp2 j >>= (fun u2 ->
              return
                ( compose (pair u1 u2) (compose (from (next_prod j)) (map_next j eval)),
                  tp3 ))
        | Discrete ->
            check env e2 (InT (tploc, Discrete)) j >>= (fun u2 ->
              return
                ( compose
                    (pair (compose u1 (map_next j exp_discrete)) u2)
                    (compose (from (next_prod j)) (map_next j eval)),
                  InT (tploc, Discrete) ))
        | _ -> mismatch1 _loc "synth app: expected function type" tp1)
  | Head e1 ->
      synth env e1 j >>= (fun (u1, (InT (_, tp) as tp1)) ->
        match tp with
        | Stream tp2 -> return (compose u1 (map_next j head), tp2)
        | _ -> mismatch1 _loc "synth head: expected stream type" tp1)
  | Tail e1 ->
      (* The tail of a stream available at time [j - 1] is available at [j]. *)
      if j > 0 then
        synth env e1 (j - 1) >>= (fun (u1, (InT (_, tp) as tp1)) ->
          match tp with
          | Stream _ -> return (compose u1 (map_next (j - 1) tail), tp1)
          | _ -> mismatch1 _loc "synth tail: expected stream type" tp1)
      else error_loc _loc "Can't create tail at time 0"
  | Map (e1, e2) ->
      synth env e1 j >>= (fun (u1, (InT (_, tp) as tp1)) ->
        match tp with
        | Arrow (tp2, tp3) ->
            check env e2 (InT (_loc, Stream tp2)) j >>= (fun u2 ->
              let u3 =
                compose
                  (compose (pair u1 u2) (from (next_prod j)))
                  (map_next j stream_strength)
              in
              return (u3, InT (_loc, Stream tp3)))
        | _ -> mismatch1 _loc "Expected function type" tp1)
  | Await e1 ->
      (* A [Next tp] value at time [j - 1] yields a [tp] value at time [j]. *)
      if j > 0 then
        synth env e1 (j - 1) >>= (fun (u1, (InT (_, tp) as tp1)) ->
          match tp with
          | Next tp2 -> return (u1, tp2)
          | _ -> mismatch1 _loc "synth await: expected next type" tp1)
      else error_loc _loc "Attempted to await at time 0"
  | Let (p, e1, e2) ->
      synth env e1 j >>= (fun (u1, tp1) ->
      synth ((p, tp1, j) ++ env) e2 j >>= (fun (u2, tp2) ->
      return (compose (pair id u1) (compose (pattern p j) u2), tp2)))
  | LetGui (p, e1, e2) ->
      synth env e1 j >>= (fun (u1, InT (tploc, tp1)) ->
        match tp1 with
        | Gui tp1' ->
            synth ((p, tp1', j) ++ env) e2 j >>= (fun (u2, (InT (_, tp2) as tp2')) ->
              match tp2 with
              | Gui _ ->
                  return
                    ( compose (pair id u1)
                        (compose (from (next_prod j))
                           (map_next j (compose strength (bind (compose (pattern p j) u2))))),
                      tp2' )
              | _ -> mismatch1 _loc "Expected GUI type" tp2')
        | _ -> mismatch1 _loc "Expected GUI type" (InT (tploc, tp1)))
  | Annot (e1, tp1) -> check env e1 tp1 j >>= (fun u -> return (u, tp1))
  | _ -> error_loc _loc "synth: Cannot synthesize type for checking term"
         

(* The elaborate function tries to elaborate terms in the empty context *)

(* Elaborate term [e] in the empty context at time 0. On success, the
   elaborated term is applied to "Dsl" and [_loc]; a type error is re-raised
   as a [Failure] located where the error was reported. *)
let elaborate e _loc =
  let outcome = synth [] e 0 in
  match outcome with
  | Error (loc, msg) -> Loc.raise loc (Failure msg)
  | Ok (u, _) -> u "Dsl" _loc


