open Camlp4.PreCast
  
open Syntax
open Elaborator
open Term
open Grammar  

(* The payoff for piling all the syntactic categories into one big soup 
   comes here -- it makes parsing very simple, and lets us generate somewhat
   better error messages (since we can use semantic info to identify problems). 
*)
            
EXTEND Gram
  GLOBAL: expr mtype mexpr ;

  (* Types of the embedded language.  The first level holds the two binary
     type formers (arrow and product); the second level holds the unit
     type, the constructor-applied forms and parenthesised types.
     NOTE(review): the arrow and product rules share a single RIGHTA
     level, so they appear to have equal precedence -- a product followed
     by an arrow would nest the arrow inside the product.  Confirm that
     this is intended. *)
  mtype:
    [ RIGHTA
      [ dom = mtype; "->"; cod = mtype ->
          InT(_loc, Arrow(dom, cod))
      | lhs = mtype; "*"; rhs = mtype ->
          InT(_loc, Prod(lhs, rhs)) ]
    | [ LIDENT "one" ->
          InT(_loc, One)
      | UIDENT "Next"; "("; inner = mtype; ")" ->
          InT(_loc, Next(inner))
      | UIDENT "S"; "("; elt = mtype; ")" ->
          InT(_loc, Stream(elt))
      (* The ctyp argument is parsed but its value is discarded: the
         resulting Discrete node carries no payload. *)
      | UIDENT "D"; "("; _ = ctyp; ")" ->
          InT(_loc, Discrete)
      | UIDENT "Gui"; "("; inner = mtype; ")" ->
          InT(_loc, Gui(inner))
      | "("; inner = mtype; ")" ->
          inner ]
    ];

  (* Patterns.  Every rule yields a pair: the pattern itself and an
     optional type for it (None when the pattern carries no annotation).
     The unit pattern is always typed as One.  For a pair pattern the two
     components must agree: either both are annotated, giving a product
     type, or neither is; a mixed pair is rejected with a located
     Failure. *)
  mpat:
    [ [ "("; ")" ->
          (PUnit, Some (InT(_loc, One)))
      | "("; (lp, lt) = mpat; ","; (rp, rt) = mpat; ")" ->
          (let pair = PPair(lp, rp) in
           match lt, rt with
           | Some a, Some b ->
               (pair, Some (InT(_loc, Prod(a, b))))
           | None, None ->
               (pair, None)
           | Some _, None | None, Some _ ->
               Loc.raise _loc (Failure "All or no pattern variables should be annotated"))
      | "("; inner = mpat; ")" ->
          inner
      | name = LIDENT; ":"; annot = mtype ->
          (PVar name, Some annot)
      | name = LIDENT ->
          (PVar name, None)
      ]
    ];

  (* Expressions of the embedded language, split into four named levels:
     binders (fun and the let forms), annotation, application (the
     built-in operators and juxtaposition) and atomic.  Every rule builds
     an In node tagged with a source location. *)
  mexpr: [ "binders"
           [ (* Type annotations on the parameters of fun are parsed but
                dropped; only the patterns are passed to mk_fun. *)
             "fun"; ps = LIST1 [(p, _) = mpat -> p]; "->"; body = mexpr ->
               mk_fun _loc ps body
             (* let rec x : T = e in ...  becomes a let binding x to an
                annotated Fix node. *)
           | "let"; "rec"; x = LIDENT; ":"; tp1 = mtype; "="; e = mexpr; "in"; e' = mexpr ->
               In(_loc, Let((PVar x), In(_loc, Annot(In(_loc, Fix(x, e)), tp1)), e'))
             (* let rec f p : S(T) = <miter> in ...  The miter entry
                returns a function that expects the name f; applying it
                yields the body.  The body is wrapped in a function of p,
                annotated with  tp1 -> (tp2 * Next tp1)  and passed to
                Unfold.  The argument pattern must carry a type. *)
           | "let"; "rec"; f = LIDENT; (p, tp1') = mpat;
             ":"; UIDENT "S"; "("; tp2 = mtype; ")"; "="; e = miter; "in"; e' = mexpr ->
               (match tp1' with
                | None -> Loc.raise _loc (Failure "Argument needs type annotation")
                | Some tp1 ->
                    let e = e f in
                    let tp = InT(_loc, Arrow(tp1, InT(_loc, Prod(tp2, InT(_loc, Next tp1))))) in
                    In(_loc, Let((PVar f),
                                 In(_loc, Unfold(In(_loc, Annot(mk_fun _loc [p] e, tp)))),
                                 e')))
             (* let gui p = e in ...  An annotated pattern of type T makes
                the bound expression annotated with Gui(T). *)
           | "let"; LIDENT "gui"; (p, tp') = mpat; "="; e = mexpr; "in"; e' = mexpr ->
               (match tp' with
                | None    -> In(_loc, LetGui(p, e, e'))
                | Some tp -> In(_loc, LetGui(p, In(_loc, Annot(e, InT(_loc, Gui(tp)))), e')))
             (* Plain let.  When the pattern is annotated, the bound
                expression is wrapped in an Annot node whose location is
                the merge of the type span and the expression span.  The
                type is written before the expression, and Loc.merge takes
                its start from the first argument and its end from the
                second, so the type location must come first.
                NOTE(review): assumes loc_tp and loc_exp return the source
                span of their argument -- they are defined elsewhere. *)
           | "let"; (p, tp') = mpat; "="; e = mexpr; "in"; e' = mexpr ->
               (match tp' with
                | None -> In(_loc, Let(p, e, e'))
                | Some tp -> In(_loc, Let(p, In(Loc.merge (loc_tp tp) (loc_exp e), Annot(e, tp)), e')))
           ]
         | "annotation"
           [ e = mexpr; ":";  t = mtype -> In(_loc, Annot(e, t)) ]
         | "application"
           [ LIDENT "head"; e = mexpr -> In(_loc, Head(e))
           | LIDENT "tail"; e = mexpr -> In(_loc, Tail(e))
           | LIDENT "await"; e = mexpr -> In(_loc, Await(e))
           | LIDENT "next"; e = mexpr -> In(_loc, NextE(e))
           | LIDENT "fst"; e = mexpr -> In(_loc, Fst(e))
           | LIDENT "snd"; e = mexpr -> In(_loc, Snd(e))
           | LIDENT "return"; e = mexpr -> In(_loc, GuiReturn e)
           | LIDENT "cons"; "(";  e = mexpr; ","; e' = mexpr; ")" -> In(_loc, Cons(e, e'))
           | LIDENT "map"; "("; e = mexpr; ","; e' = mexpr; ")" -> In(_loc, Map(e, e'))
           | LIDENT "zip"; "("; e = mexpr; ","; e' = mexpr; ")" -> In(_loc, Zip(e, e'))
             (* Juxtaposition is function application. *)
           |  m = mexpr; m' = mexpr -> In(_loc, App(m, m'))  ]
         | "atomic"
           [ v = LIDENT -> In(_loc, Var v)
           |  "("; ")" -> In(_loc, Unit)
           |  "("; e = mexpr; ")" -> e
           |  "("; e = mexpr; ","; e' = mexpr; ")" -> In(_loc, Pair(e, e'))
             (* Square brackets and braces each enclose an expr and
                produce an Embed or Inject node respectively. *)
           |  "["; e = expr; "]" -> In(_loc, Embed(e))
           |  "{"; e = expr; "}" -> In(_loc, Inject(e))
           ]
         ];

  (* Body of a stream iterator, as used by the  let rec f p : S(T) = ...
     form of mexpr.  Each rule returns a function that expects the name of
     the enclosing recursive function and only then builds the term, so
     the check on the recursive call can be made once that name is
     known. *)
  miter:
    [ [ (* cons(e1, g e2): the tail must be a call to the function being
           defined.  The result is the pair of e1 and NextE e2; any other
           callee name raises a located Failure. *)
        LIDENT "cons"; "("; e1 = mexpr; ","; g = LIDENT; e2 = mexpr; ")" ->
          (fun f ->
            if f = g then
              In(_loc, Pair(e1, In(_loc, NextE e2)))
            else
              Loc.raise _loc (Failure (Printf.sprintf "Expected function name '%s', got '%s'" f g)))
        (* A let in front of the iterator body; the function name is
           threaded through to the inner miter.  For an annotated pattern
           the bound expression is wrapped in an Annot node located at the
           merge of the type span and the expression span.  The type is
           written before the expression, and Loc.merge takes its start
           from the first argument and its end from the second, so the
           type location must come first.
           NOTE(review): assumes loc_tp and loc_exp return the source span
           of their argument -- they are defined elsewhere. *)
      | "let"; (p, tp') = mpat; "="; e = mexpr; "in"; e' = miter ->
          (match tp' with
           | None -> fun f -> In(_loc, Let(p, e, e' f))
           | Some tp -> fun f ->
               In(_loc, Let(p, In(Loc.merge (loc_tp tp) (loc_exp e), Annot(e, tp)), e' f)))
      ]
    ];

  (* Extends the existing expr entry at its level named top with the form
     do U( ... ): the parenthesised text is parsed with mexpr and the
     resulting term is handed to elaborate together with the location of
     the whole form. *)
  expr: LEVEL "top"
    [ [ "do"; UIDENT "U"; "("; embedded = mexpr; ")" ->
          elaborate embedded _loc
      ]
    ];
END
