
(* Abort with an internal error reporting that feature [s] is not implemented. *)
fun not_yet s =
    let val msg = "not yet implemented : " ^ s
    in Library.internal_error msg end


(* Shared IR integer constants. *)
val (zero, one) = (Tree.CONST 0, Tree.CONST 1)
  
(* seq [s1, s2, ..., sn] builds the right-nested statement
   SEQ(s1, SEQ(s2, ... sn)).  A singleton list yields its only element
   unchanged; an empty list is an internal error. *)
fun seq stms =
    case stms of
        [] => Library.internal_error "seq : given empty list!"
      | [last] => last
      | first :: rest => Tree.SEQ (first, seq rest)

(* Skip translates to the value expression 0. *)
val tr_skip          = Tree.Ex(zero)

(* Integer literals become IR constants. *)
fun tr_integer n     = Tree.Ex(Tree.CONST n)

(* Boolean literals become conditionals that jump unconditionally to the
   true label t or the false label f.  The JUMP's target list names the
   label actually jumped to, matching the JUMP(NAME l, [l]) form used by
   tr_while and tr_if.  An empty target list would hide the successor
   from any later pass that reads a JUMP's possible targets. *)
fun tr_boolean true  = Tree.Cx(fn (t, _) => Tree.JUMP(Tree.NAME t, [t]))
  | tr_boolean false = Tree.Cx(fn (_, f) => Tree.JUMP(Tree.NAME f, [f]))

(* Translate a unary operator applied to an already-translated operand e.
   Neg computes 0 - e.  Not is a conditional that goes to the false label
   when e is non-zero and to the true label otherwise. *)
fun tr_unary_op AST_expr.Neg e =
      Tree.Ex (Tree.BINOP (Tree.MINUS, zero, Tree.unEx e))
  | tr_unary_op AST_expr.Not e =
      Tree.Cx (fn (t, f) => Tree.CJUMP (Tree.NE, Tree.unEx e, zero, f, t))

(* Translate a binary operator applied to two translated operands.
   Arithmetic operators yield value expressions (Ex).  Comparisons yield
   conditionals (Cx) that jump to t when the comparison holds and to f
   otherwise. *)
fun tr_op oper e1 e2 =
    let
      fun arith binop =
          Tree.Ex (Tree.BINOP (binop, Tree.unEx e1, Tree.unEx e2))
      fun relation relop =
          Tree.Cx (fn (t, f) => Tree.CJUMP (relop, Tree.unEx e1, Tree.unEx e2, t, f))
    in
      case oper of
          AST_expr.Plus => arith Tree.PLUS
        | AST_expr.Mult => arith Tree.MUL
        | AST_expr.Subt => arith Tree.MINUS
        | AST_expr.GTEQ => relation Tree.GE
        | AST_expr.EQ   => relation Tree.EQ
    end

(* Translate a sequence of translated expressions.  Every element except
   the last is evaluated for effect (unNx), and the sequence's value is
   that of the last element.  The empty sequence is the value 0. *)
fun tr_seq exps =
    case exps of
        [] => Tree.Ex zero
      | [only] => only
      | first :: more =>
          Tree.Ex (Tree.ESEQ (Tree.unNx first, Tree.unEx (tr_seq more)))

(* See pages 164--165 of Appel's book.  But note that 
   Appel's Tiger language has a "break" statement which 
   complicates loops.  Our Slang languages do not have
   this construct. 
*) 

(* Translate `while loop_condition do loop_body` into:

       top:    if loop_condition then goto enter else goto exit
       enter:  loop_body
               goto top
       exit:
*)
fun tr_while loop_condition loop_body =
    let
      val test = Tree.unCx loop_condition
      val stmt = Tree.unNx loop_body
      val (top, enter, exit) = (Temp.newlabel (), Temp.newlabel (), Temp.newlabel ())
    in
      Tree.Nx (seq [ Tree.LABEL top
                   , test (enter, exit)
                   , Tree.LABEL enter
                   , stmt
                   , Tree.JUMP (Tree.NAME top, [top])
                   , Tree.LABEL exit ])
    end

(* See pages 161--163 of Appel's book. *) 
(* Translate `if e1 then then_e else else_e`, given the already-translated
   condition e1 and branches then_e / else_e.  The kind of the result
   depends on the kinds of the two branches:
     (Ex, Ex)  the chosen branch's value is moved into a fresh temp r, and
               the whole conditional is an ESEQ yielding TEMP r;
     (Cx, Cx)  the result is a Cx whose arms jump to the caller's (t, f);
     (Nx, Nx)  the result is a statement; both arms continue at joinLabel;
     mixed     one branch is coerced (Ex with Cx -> both Ex; anything with
               Nx -> both Nx) and tr_if is called again on the matching pair.
   NOTE(review): each call allocates then_label / else_label up front, so the
   mixed cases (which re-enter tr_if) discard them and allocate fresh ones.
   Preserve the order of Temp.newlabel / Temp.newtemp calls if generated
   names must stay stable. *)
fun tr_if e1 then_e else_e  =
      let
        val cond = Tree.unCx(e1)
        val then_label = Temp.newlabel()
        val else_label = Temp.newlabel()
      in
        case (then_e, else_e) of 
          (Tree.Ex _, Tree.Ex _) =>
          (* r receives the value of whichever arm runs; the then-arm jumps
             over the else-arm to joinLabel. *)
          let val r = Temp.newtemp()
              val joinLabel = Temp.newlabel()
          in 
          Tree.Ex (Tree.ESEQ (seq [cond (then_label, else_label),
                              Tree.LABEL then_label,
                              Tree.MOVE (Tree.TEMP r, Tree.unEx then_e),
                              Tree.JUMP (Tree.NAME joinLabel, [joinLabel]),
                              Tree.LABEL else_label,
                              Tree.MOVE (Tree.TEMP r, Tree.unEx else_e),
                              Tree.LABEL joinLabel],
                      Tree.TEMP r))
          end 
        | (Tree.Cx _, Tree.Cx _) =>
            (* No join label: this relies on then_exp ending in a jump to t
               or f (as the Cx builders in this file do), so control never
               falls through into else_label.
               NOTE(review): then_label / else_label are captured by this
               closure, so applying the returned Cx function more than once
               would emit duplicate LABELs; the uses visible in this file
               apply it once. *)
            Tree.Cx (fn (t, f) =>
                let val then_exp = (Tree.unCx then_e) (t, f)
                    val else_exp = (Tree.unCx else_e) (t, f)
                in 
                   seq [cond (then_label, else_label),
                           Tree.LABEL then_label,
                           then_exp,
                           Tree.LABEL else_label,
                           else_exp]
                end)
        | (Tree.Nx _, Tree.Nx _) =>
          (* Statement form: the then-arm jumps over the else-arm to joinLabel. *)
          let val joinLabel = Temp.newlabel()
          in 
          Tree.Nx (seq [(cond) (then_label, else_label),
                   Tree.LABEL then_label,
                   Tree.unNx then_e,
                   Tree.JUMP (Tree.NAME joinLabel, [joinLabel]),
                   Tree.LABEL else_label,
                   Tree.unNx else_e,
                   Tree.LABEL joinLabel])
          end 
        (* Mixed kinds: coerce one branch so both match, then retry. *)
        | (Tree.Ex _, Tree.Cx _) => tr_if e1 then_e (Tree.Ex (Tree.unEx else_e)) 
        | (Tree.Ex _, Tree.Nx _) => tr_if e1 (Tree.Nx (Tree.unNx then_e)) else_e
        | (Tree.Cx _, Tree.Ex _) => tr_if e1 (Tree.Ex (Tree.unEx then_e)) else_e 
        | (Tree.Cx _, Tree.Nx _) => tr_if e1 (Tree.Nx (Tree.unNx then_e)) else_e
        | (Tree.Nx _, Tree.Ex _) => tr_if e1 then_e (Tree.Nx (Tree.unNx else_e)) 
        | (Tree.Nx _, Tree.Cx _) => tr_if e1 then_e (Tree.Nx (Tree.unNx else_e)) 
    end

(* The built-in "print" primitive (see the AppDir("print", [e]) case in
   trans) becomes a call to the function labelled "print" with the
   argument's value.
   NOTE(review): originally marked "ugly!" -- presumably because the
   runtime routine is referenced by a hard-coded label; confirm. *)
fun tr_print e =
    let val arg = Tree.unEx e
    in Tree.Ex (Tree.CALL (Tree.NAME "print", [arg])) end
      
(* Ref and Deref need no IR of their own: the translated operand is
   returned unchanged. *)
fun tr_ref exp   = exp
fun tr_deref exp = exp

(* Assignment: MOVE the value of e2 into the destination e1. *)
fun tr_assign e1 e2 =
    let
      val dst = Tree.unEx e1
      val src = Tree.unEx e2
    in
      Tree.Nx (Tree.MOVE (dst, src))
    end

(* let x = e1 in e2: assign e1 to temp x, then yield the value of e2. *)
fun tr_let x e1 e2 =
    let
      val bind  = Tree.unNx (tr_assign (Tree.Ex (Tree.TEMP x)) e1)
      val value = Tree.unEx e2
    in
      Tree.Ex (Tree.ESEQ (bind, value))
    end

(* Build a CLOSURE node from f and the values of the translated
   expressions el. *)
fun tr_closure f el =
    let val fields = List.map Tree.unEx el
    in Tree.Ex (Tree.CLOSURE (f, fields)) end

(* A variable is read directly from its temporary. *)
fun tr_var v = Tree.Ex (Tree.TEMP v)

(* Function references and calls.  tr_direct names a label; tr_appdir and
   tr_appcls call the function labelled f; tr_app calls the function
   computed by e.  Arguments are the values (unEx) of the translated
   argument expressions.
   NOTE(review): closure calls (tr_appcls) translate exactly like direct
   calls here -- confirm this is intended. *)
fun tr_direct v = Tree.Ex (Tree.NAME v)

fun tr_appdir f el =
    let val args = List.map Tree.unEx el
    in Tree.Ex (Tree.CALL (Tree.NAME f, args)) end

fun tr_appcls f el = tr_appdir f el

fun tr_app e el =
    let
      val callee = Tree.unEx e
      val args   = List.map Tree.unEx el
    in
      Tree.Ex (Tree.CALL (callee, args))
    end

(* Translation of AST_closure expressions into Tree IR. *) 

open AST_closure 

(* Set operations used below to collect temporaries, borrowed from
   AST_expr.  A dedicated Set module would be cleaner. *)
local
  structure SetOps = AST_expr
in
  val union       = SetOps.union
  val empty_set   = SetOps.empty_set
  val singleton   = SetOps.singleton
  val difference  = SetOps.difference
  val list_to_set = SetOps.list_to_set
end

(* Translate an AST_closure expression into an Ex/Nx/Cx translation.
   A call AppDir("print", [e]) is routed to tr_print; any other AppDir
   becomes an ordinary direct call. *)
fun trans expr =
    case expr of
        Skip                  => tr_skip
      | Integer n             => tr_integer n
      | Boolean b             => tr_boolean b
      | Deref e               => tr_deref (trans e)
      | UnaryOp (uop, e)      => tr_unary_op uop (trans e)
      | Op (oper, e1, e2)     => tr_op oper (trans e1) (trans e2)
      | Assign (e1, e2)       => tr_assign (trans e1) (trans e2)
      | Seq _                 => tr_seq (trans_list expr)
      | If (e1, e2, e3)       => tr_if (trans e1) (trans e2) (trans e3)
      | While (e1, e2)        => tr_while (trans e1) (trans e2)
      | Ref e                 => tr_ref (trans e)
      | Var v                 => tr_var v
      | AppDir ("print", [e]) => tr_print (trans e)
      | AppDir (f, el)        => tr_appdir f (List.map trans el)
      | AppCls (f, el)        => tr_appcls f (List.map trans el)
      | App (e, el)           => tr_app (trans e) (List.map trans el)
      | Let (v, e1, e2)       => tr_let v (trans e1) (trans e2)
      | Closure (f, el)       => tr_closure f (List.map trans el)
      | Direct f              => tr_direct f

(* Flatten nested Seq nodes into a list of translations, in source order. *)
and trans_list expr =
    case expr of
        Seq (first, second) => trans_list first @ trans_list second
      | other               => [trans other]

(* Translate an expression and coerce the result to a statement. *)
fun trans_expr e =
    let val translated = trans e
    in Tree.unNx translated end

(* Collect the set of temporaries mentioned anywhere in an IR statement or
   expression, both as MOVE destinations and as operands.  trans_fun uses
   this to compute a function's locals. *)
fun get_temps_stm (Tree.SEQ (s1, s2))            = union (get_temps_stm s1, get_temps_stm s2)
  | get_temps_stm (Tree.LABEL _)                 = empty_set
  | get_temps_stm (Tree.JUMP (e, _))             = get_temps_exp e
  | get_temps_stm (Tree.CJUMP (_, e1, e2, _, _)) = union (get_temps_exp e1, get_temps_exp e2)
  | get_temps_stm (Tree.MOVE (e1, e2))           = union (get_temps_exp e1, get_temps_exp e2)
  | get_temps_stm (Tree.EXP e)                   = get_temps_exp e

and get_temps_exp (Tree.BINOP (_, e1, e2)) = union (get_temps_exp e1, get_temps_exp e2)
  | get_temps_exp (Tree.MEM e)             = get_temps_exp e
  | get_temps_exp (Tree.TEMP t)            = singleton t
  | get_temps_exp (Tree.ESEQ (s, e))       = union (get_temps_stm s, get_temps_exp e)
  | get_temps_exp (Tree.NAME _)            = empty_set
  | get_temps_exp (Tree.CONST _)           = empty_set
  | get_temps_exp (Tree.CALL (e, el))      = union (get_temps_exp e, get_temps_exp_list el)
  (* Closure fields are ordinary IR expressions (see tr_closure) and may
     themselves mention temporaries. *)
  | get_temps_exp (Tree.CLOSURE (_, el))   = get_temps_exp_list el

and get_temps_exp_list []        = empty_set
  | get_temps_exp_list [e]       = get_temps_exp e
  | get_temps_exp_list (e::rest) = union (get_temps_exp e, get_temps_exp_list rest)

(* Translate one function declaration.  Its locals are the temporaries used
   in the translated body, excluding the argument and the variables in env. *)
fun trans_fun (FunDecl (name, _, arg, env, body)) =
    let
      val stm    = trans_expr body
      val used   = get_temps_stm stm
      val bound  = list_to_set (arg :: env)
      val locals = Binaryset.listItems (difference (used, bound))
    in
      Tree.FunDecl (name, arg, env, locals, stm)
    end

(* Translate a whole program: every function declaration is translated,
   and the second component f is passed through unchanged. *)
fun translate (Prog (fun_list, f)) =
    let val decls = List.map trans_fun fun_list
    in Tree.Prog (decls, f) end
     
