
open AST_closure; 

(* Short local names for the AST_expr helpers used by the conversion below. *)
(* free_vars e: its result is handed to delete / member, and (after delete)
   to Binaryset.listItems, so it is a Binaryset of variable names. *)
val free_vars = AST_expr.free_vars 
(* delete (s, x): used below to drop x from such a set. *)
val delete    = AST_expr.delete 
(* member (s, x): used below as the test that decides a function's is_rec flag. *)
val member    = AST_expr.member

(* conv e

   Closure-converts one AST_expr expression.  The result is a pair:
     - the list of FunDecl entries created for every Fn / Letrecfn found in e
       (the declaration for a node comes before those of its sub-expressions,
       which are concatenated left to right);
     - the rewritten expression, where each function is referenced as
       `Direct f` when it captures nothing, or `Closure (f, [Var v, ...])`
       listing its captured variables otherwise.

   Sub-expressions are always converted left to right, and the name for an
   anonymous Fn is taken from Temp.newtemp () before its body is converted.

   NOTE(review): a `Let` whose bound expression is a `Fn` is routed through
   the Letrecfn case, so it is flagged recursive whenever the body mentions
   the bound name -- confirm the front end never relies on such a mention
   referring to an outer binding. *)
fun conv expr =
    let
        (* Convert one sub-expression and rebuild the node with mk. *)
        fun unary mk sub =
            let val (decls, sub') = conv sub
            in (decls, mk sub') end

        (* Convert two sub-expressions (left one first) and rebuild with mk. *)
        fun binary mk (lhs, rhs) =
            let val (decls_l, lhs') = conv lhs
                val (decls_r, rhs') = conv rhs
            in (decls_l @ decls_r, mk (lhs', rhs')) end

        (* How a converted function is referenced at its definition site. *)
        fun fun_ref (name, [])       = Direct name
          | fun_ref (name, captured) = Closure (name, List.map Var captured)
    in
        case expr of
            AST_expr.Skip                 => ([], Skip)
          | AST_expr.Integer n            => ([], Integer n)
          | AST_expr.Boolean b            => ([], Boolean b)
          | AST_expr.Var v                => ([], Var v)
          | AST_expr.Deref e              => unary Deref e
          | AST_expr.Ref e                => unary Ref e
          | AST_expr.UnaryOp (uop, e)     => unary (fn e' => UnaryOp (uop, e')) e
          | AST_expr.Op (oper, e1, e2)    =>
                binary (fn (a, b) => Op (oper, a, b)) (e1, e2)
          | AST_expr.Assign (e1, _, e2)   => binary Assign (e1, e2)
          | AST_expr.Seq (e1, e2)         => binary Seq (e1, e2)
          | AST_expr.While (e1, e2)       => binary While (e1, e2)
          | AST_expr.App (e1, e2)         =>
                (* the closure AST takes an argument list; source has one argument *)
                binary (fn (callee, arg) => App (callee, [arg])) (e1, e2)
          | AST_expr.If (e1, e2, e3)      =>
                let val (decls1, e1') = conv e1
                    val (decls2, e2') = conv e2
                    val (decls3, e3') = conv e3
                in
                    (decls1 @ decls2 @ decls3, If (e1', e2', e3'))
                end
          | AST_expr.Let (f, fun_t, AST_expr.Fn (x, arg_t, e1), e2) =>
                conv (AST_expr.Letrecfn (f, fun_t, x, arg_t, e1, e2))
          | AST_expr.Let (x, _, e1, e2)   =>
                binary (fn (bound, body) => Let (x, bound, body)) (e1, e2)
          | AST_expr.Fn (x, _, body)      =>
                let val label       = Temp.newtemp ()
                    val captured    = Binaryset.listItems (delete (free_vars body, x))
                    val (decls, body') = conv body
                    val this_decl   = FunDecl (label, false, x, captured, body')
                in
                    (this_decl :: decls, fun_ref (label, captured))
                end
          | AST_expr.Letrecfn (f, _, x, _, e1, e2) =>
                let val body_fvs    = free_vars e1
                    val recursive   = member (body_fvs, f)
                    (* neither the parameter nor the function itself is captured *)
                    val captured    = Binaryset.listItems (delete (delete (body_fvs, x), f))
                    val (decls1, e1') = conv e1
                    val (decls2, e2') = conv e2
                    val this_decl   = FunDecl (f, recursive, x, captured, e1')
                in
                    (this_decl :: (decls1 @ decls2), Let (f, fun_ref (f, captured), e2'))
                end
    end

(* convert e

   Closure-converts a whole program: e becomes the body of a non-recursive
   function "_MAIN" (parameter "_DUMMY", no captured variables), which is
   placed in front of the functions produced by conv.  The resulting Prog
   carries "_MAIN" as its second component. *)
fun convert e = 
    let val main_name = "_MAIN"
        val (lifted, body) = conv e
        val main_decl = FunDecl (main_name, false, "_DUMMY", [], body)
    in
        Prog (main_decl :: lifted, main_name)
    end 

(* simplification *) 

(* constant folding and other local simplifications

   fold e rewrites e bottom-up: the children of a node are folded first and
   the local rules are then tried on the rebuilt node, so nested constants
   collapse completely (e.g. (1 + 2) + 3 becomes 6, and a condition such as
   `not true` is resolved before the If rule looks at it).

   Local rules:
     - Neg / Not on a literal, and Plus / Mult / Subt / GTEQ / EQ on two
       integer literals, are evaluated;
     - Seq with a Skip on either side keeps only the other side;
     - If on a boolean literal keeps only the selected branch;
     - While on `false` becomes Skip;
     - Deref (Ref e) becomes e;
     - App (Direct f, args) becomes AppDir (f, args).

   Arithmetic that raises Overflow while being evaluated here is left
   unfolded instead of aborting the pass.  Anything not listed is returned
   unchanged (sub-expressions of Closure are not visited). *)
fun fold expr =
    let
        (* Rules for a node whose children have already been folded. *)
        fun simp (e as UnaryOp (AST_expr.Neg, Integer n)) =
              (Integer (0 - n) handle Overflow => e)
          | simp (UnaryOp (AST_expr.Not, Boolean b)) = Boolean (not b)
          | simp (e as Op (AST_expr.Plus, Integer n, Integer m)) =
              (Integer (n + m) handle Overflow => e)
          | simp (e as Op (AST_expr.Mult, Integer n, Integer m)) =
              (Integer (n * m) handle Overflow => e)
          | simp (e as Op (AST_expr.Subt, Integer n, Integer m)) =
              (Integer (n - m) handle Overflow => e)
          | simp (Op (AST_expr.GTEQ, Integer n, Integer m)) = Boolean (m <= n)
          | simp (Op (AST_expr.EQ, Integer n, Integer m))   = Boolean (m = n)
          | simp (Seq (Skip, e))       = e
          | simp (Seq (e, Skip))       = e
          | simp (Deref (Ref e))       = e
          | simp (App (Direct f, el))  = AppDir (f, el)
          | simp e                     = e
    in
        case expr of
            UnaryOp (uop, e)  => simp (UnaryOp (uop, fold e))
          | Op (bop, e1, e2)  => simp (Op (bop, fold e1, fold e2))
          | Assign (e1, e2)   => Assign (fold e1, fold e2)
          | Seq (e1, e2)      => simp (Seq (fold e1, fold e2))
          | If (c, e2, e3)    =>
                (* decide on the folded condition; a dead branch is not visited *)
                (case fold c of
                     Boolean true  => fold e2
                   | Boolean false => fold e3
                   | c'            => If (c', fold e2, fold e3))
          | While (c, body)   =>
                (case fold c of
                     Boolean false => Skip
                   | c'            => While (c', fold body))
          | Deref e           => simp (Deref (fold e))
          | Ref e             => Ref (fold e)
          | AppCls (f, el)    => AppCls (f, List.map fold el)
          | AppDir (f, el)    => AppDir (f, List.map fold el)
          | App (e, el)       => simp (App (fold e, List.map fold el))
          | Let (x, e1, e2)   => Let (x, fold e1, fold e2)
          | other             => other
    end

(* Run fold over the body of one function declaration; the other four
   fields are copied through unchanged. *)
fun fold_fun decl =
    case decl of
        FunDecl (name, recursive, param, captured, body) =>
            FunDecl (name, recursive, param, captured, fold body)

(* Run fold_fun over every function of a program; the second component of
   Prog is kept as is. *)
fun fold_prog (Prog (decls, entry)) =
    let val folded = List.map fold_fun decls
    in Prog (folded, entry) end

(* Look v up in env; if Library.lookup raises (any exception is caught),
   the variable is left in place as Var v. *)
fun beta_find env v =
    let val bound = Library.lookup (env, v)
    in bound end
    handle _ => Var v

(* True exactly for Skip, Integer, Boolean, Var, Direct and Closure nodes;
   beta uses this to decide whether a let-bound expression may be
   substituted for its variable. *)
fun is_value e =
    case e of
        Skip      => true
      | Integer _ => true
      | Boolean _ => true
      | Var _     => true
      | Direct _  => true
      | Closure _ => true
      | _         => false

(* beta env e

   Propagates let-bound values through e.  env maps variable names to the
   expressions that replace them (looked up through beta_find, which leaves
   unknown variables untouched).

     - `Let (x, e1, e2)` whose bound expression is a value (see is_value)
       disappears: x is mapped to that value while e2 is processed.
     - Any other Let is kept, with both parts processed.
     - `App (Var "print", args)` is turned into `AppDir ("print", args)`
       before any lookup of the name "print".
     - Every other node is rebuilt with its sub-expressions processed;
       nodes without a clause of their own are returned unchanged.

   NOTE(review): substitution is not capture-avoiding.  If env maps y to
   `Var x` and the body later rebinds x, the substituted `Var x` refers to
   the inner binding.  This is only safe if bound names are unique --
   confirm against the front end. *)
fun beta env (UnaryOp (uop, e))  = UnaryOp (uop, beta env e)
  | beta env (Op (bop, e1, e2))  = Op (bop, beta env e1, beta env e2)
  | beta env (Assign(e1,e2))     = Assign(beta env e1, beta env e2)
  | beta env (Seq(e1, e2))       = Seq(beta env e1, beta env e2)
  | beta env (If(e1, e2, e3))    = If(beta env e1, beta env e2, beta env e3)
  | beta env (While(e1, e2))     = While(beta env e1, beta env e2)
  | beta env (Deref e)           = Deref (beta env e)
  | beta env (Ref e)             = Ref(beta env e)
  | beta env (AppCls(f, el))     = AppCls(f, List.map (beta env) el)
  | beta env (AppDir(f, el))     = AppDir(f, List.map (beta env) el)
  | beta env (Closure(f, el))    = Closure(f, List.map (beta env) el)
  | beta env (App(Var "print", el))  = AppDir("print", List.map (beta env) el)
  | beta env (App(e, el))        = App(beta env e, List.map (beta env) el)
  | beta env (Let(x, e1, e2))    =
    let
        (* Substitute inside the bound expression before recording it.  A
           value such as `Var y` or `Closure (f, [Var y])` may mention a
           variable whose own Let has already been removed; storing it
           unsubstituted would leave that variable unbound. *)
        val e1' = beta env e1
    in
        if is_value e1'
        then beta (Library.update (env, x, e1')) e2
        else
            (* x stays let-bound here, so map it to itself: an outer
               binding of the same name must not be substituted into e2. *)
            Let(x, e1', beta (Library.update (env, x, Var x)) e2)
    end
  | beta env (Var x)             = beta_find env x 
  | beta env e                   = e 

(* Run beta over the body of one function declaration, starting from
   Library.empty_env; the other fields are copied through unchanged. *)
fun beta_fun decl =
    case decl of
        FunDecl (name, recursive, param, captured, body) =>
            FunDecl (name, recursive, param, captured, beta Library.empty_env body)

(* Run beta_fun over every function of a program; the second component of
   Prog is kept as is. *)
fun beta_prog (Prog (decls, entry)) =
    let val rewritten = List.map beta_fun decls
    in Prog (rewritten, entry) end

(* Optional lookup: SOME of the entry stored for v in env, or NONE when
   Library.lookup raises (any exception is caught). *)
fun inline_find env v =
    let val found = Library.lookup (env, v)
    in SOME found end
    handle _ => NONE

(* unfold_lets vars values e

   Wraps e in one Let per (variable, value) pair, the first pair outermost:
     unfold_lets [x1, x2] [v1, v2] e = Let (x1, v1, Let (x2, v2, e)).
   Lists of different lengths are reported through Library.internal_error. *)
fun unfold_lets vars values e =
    case (vars, values) of
        ([], [])                     => e
      | (x :: vars', v :: values')   => Let (x, v, unfold_lets vars' values' e)
      | _ => Library.internal_error "Closure.unfold_lets : mismatch in number of vars and values"

(* inline env e

   Replaces calls to functions recorded in env by their bodies.  env maps a
   function name to (parameter, captured variable names, body); see
   mk_inline_env for how it is built.

     - `AppDir (f, [arg])` with f recorded and no captured variables becomes
       `Let (param, arg, body)`.
     - `App (Closure (f, vals), [arg])` with f recorded becomes
       `Let (param, arg, <one Let per captured variable> body)`.
     - In both cases the argument and the inlined body are themselves
       processed; a call whose target is not recorded keeps its shape
       (a direct call stays AppDir) with only its arguments processed.
     - Other nodes are rebuilt with their sub-expressions processed; nodes
       without a clause of their own (e.g. AppCls, a bare Closure) are
       returned unchanged.

   NOTE(review): the generated Lets are not capture-avoiding -- the closure
   values are placed inside the scope of the parameter binding.  Safe only
   if a captured value never mentions a variable named like the parameter;
   confirm against the front end. *)
fun inline env (UnaryOp (uop, e))  = UnaryOp (uop, inline env e)
  | inline env (Op (bop, e1, e2))  = Op (bop, inline env e1, inline env e2)
  | inline env (Assign(e1,e2))     = Assign(inline env e1, inline env e2)
  | inline env (Seq(e1, e2))       = Seq(inline env e1, inline env e2)
  | inline env (If(e1, e2, e3))    = If(inline env e1, inline env e2, inline env e3)
  | inline env (While(e1, e2))     = While(inline env e1, inline env e2)
  | inline env (Deref e)           = Deref (inline env e)
  | inline env (Ref e)             = Ref(inline env e)
  | inline env (AppDir(f, [e1]))   = 
    let val e1' = inline env e1 
    in 
       case inline_find env f of 
          (* only a function without captured variables can be inlined at a direct call *)
          SOME(x, [], e2) => Let(x, e1', inline env e2) 
          (* not inlinable: the call must remain a direct call *)
        | _ => AppDir(f, [e1'])
    end 
  | inline env (AppDir(f, el))     = AppDir(f, List.map (inline env) el)
  | inline env (App(Closure(f, el), [e1])) = 
    let val el' = List.map (inline env) el 
        val e1' = inline env e1 
    in 
       case inline_find env f of 
          SOME(x, fv, e2) => Let(x, e1', unfold_lets fv el' (inline env e2))
       | _ => App(Closure(f, el'), [e1'])
    end 
       
  | inline env (App(e, el))        = App(inline env e, List.map (inline env) el)
  | inline env (Let(x, e1, e2))    = Let(x, inline env e1, inline env e2)
  | inline env e                   = e 

(* mk_inline_env env decls

   Extends env, walking decls from left to right: every declaration whose
   is_rec flag is false is recorded under its name as
   (parameter, captured variables, body).  Declarations flagged recursive
   are skipped. *)
fun mk_inline_env env decls =
    let
        fun record (FunDecl (f, false, x, fv, e), acc) = Library.update (acc, f, (x, fv, e))
          | record (_, acc)                            = acc
    in
        List.foldl record env decls
    end

(* Run inline (with the given env) over the body of one function
   declaration; the other fields are copied through unchanged. *)
fun inline_fun env decl =
    case decl of
        FunDecl (name, recursive, param, captured, body) =>
            FunDecl (name, recursive, param, captured, inline env body)

(* Build the inlining environment from all functions of the program, then
   run inline_fun with it over each of them.  The second component of Prog
   is kept as is. *)
fun inline_prog (Prog (decls, entry)) =
    let val known   = mk_inline_env Library.empty_env decls
        val rewrite = inline_fun known
    in
        Prog (List.map rewrite decls, entry)
    end

(* The simplification pipeline: beta_prog first, then inline_prog, then
   fold_prog, each applied to the result of the previous one. *)
fun simplify prog =
    let val after_beta   = beta_prog prog
        val after_inline = inline_prog after_beta
    in
        fold_prog after_inline
    end









