open InterpreterAux.Wasm
open Types
open Instance
open Values
open Source

(* Convert a native OCaml [int] into the extracted interpreter's [nat],
   going through an arbitrary-precision [Big_int]. *)
let ocaml_int_to_nat n =
  n |> Big_int.big_int_of_int |> InterpreterAux.Arith.nat_of_integer

(* Convert an [int32] into the extracted interpreter's [nat].  The value is
   rendered with [I32.to_string_u] (by its name, the unsigned decimal form)
   and re-parsed as a [Big_int], so the 32 bits are read as unsigned. *)
let ocaml_int32_to_nat n =
  let unsigned_decimal = I32.to_string_u n in
  InterpreterAux.Arith.nat_of_integer (Big_int.big_int_of_string unsigned_decimal)

(* Convert a Wasm index phrase (whose payload is an [int32]) into a [nat]. *)
let var_to_nat v =
  let raw = v.Source.it in
  ocaml_int32_to_nat raw

(* Render an extracted [nat] as a decimal string, via [Big_int]. *)
let string_of_nat n =
  let as_integer = InterpreterAux.Arith.integer_of_nat n in
  Big_int.string_of_big_int as_integer

(* [ocaml_lookup category list x] returns the element of [list] at the
   32-bit index [x.it].

   Raises [Failure] when [Lib.List32.nth] fails (presumably for an
   out-of-range index).  The message still begins with "ill-formed", and now
   also names the kind of entity looked up ([category], e.g. "type") and the
   offending index, printed unsigned.  Previously [category] was accepted
   but ignored. *)
let ocaml_lookup category list x =
  try Lib.List32.nth list x.it with Failure _ ->
    failwith (Printf.sprintf "ill-formed: undefined %s %lu" category x.it)

(* Fetch function type [ftype] from the type section of the module that
   instance [!inst] was built from.  Raises [Failure] on a bad index
   (see [ocaml_lookup]). *)
let ocaml_func_type_of inst ftype =
  let type_section = (!inst).module_.it.Ast.types in
  let entry = ocaml_lookup "type" type_section ftype in
  entry.it

(* The zero constant of the given value type, as an extracted-interpreter
   value. *)
let default_value vt =
  match vt with
  | I32Type -> ConstInt32 I32Wrapper.zero
  | F32Type -> ConstFloat32 F32Wrapper.zero
  | I64Type -> ConstInt64 I64Wrapper.zero
  | F64Type -> ConstFloat64 F64Wrapper.zero

(* Map a reference-interpreter value type to the extracted [t]. *)
let convert_t vt =
  match vt with
  | I32Type -> T_i32
  | I64Type -> T_i64
  | F32Type -> T_f32
  | F64Type -> T_f64

(* Map a global type (value type plus mutability) to the extracted [Tg_ext]
   record, carrying the mutability as [T_mut]/[T_immut]. *)
let convert_tg (GlobalType (t, mutability)) =
  let mut_flag =
    match mutability with
    | Mutable -> T_mut
    | Immutable -> T_immut
  in
  Tg_ext (mut_flag, convert_t t, ())


(* Wrap a reference-interpreter value as the matching extracted constant. *)
let convert_value v =
  match v with
  | I32 n -> ConstInt32 n
  | I64 n -> ConstInt64 n
  | F32 x -> ConstFloat32 x
  | F64 x -> ConstFloat64 x

(* Inverse of [convert_value]: turn an extracted constant back into a
   reference-interpreter value. *)
let unconvert_value c =
  match c with
  | ConstInt32 n -> I32 n
  | ConstInt64 n -> I64 n
  | ConstFloat32 x -> F32 x
  | ConstFloat64 x -> F64 x

(* Integer test operator (only [eqz]) -> extracted test operator. *)
let convert_int_testop op =
  match op with
  | Ast.IntOp.Eqz -> Eqz

(* Test operator -> [Testop] tagged with its operand type.  Only integer
   test operators are handled; a float one raises [Failure "ill-formed"]. *)
let convert_testop op =
  match op with
  | I32 o -> Testop (T_i32, convert_int_testop o)
  | I64 o -> Testop (T_i64, convert_int_testop o)
  | F32 _ | F64 _ -> failwith "ill-formed"

(* Integer comparison operator -> extracted integer relop.  The signed and
   unsigned variants map to the [S] and [U] argument respectively. *)
let convert_int_compareop op =
  match op with
  | Ast.IntOp.Eq -> Eq
  | Ast.IntOp.Ne -> Ne
  | Ast.IntOp.LtS -> Lt S
  | Ast.IntOp.LtU -> Lt U
  | Ast.IntOp.GtS -> Gt S
  | Ast.IntOp.GtU -> Gt U
  | Ast.IntOp.LeS -> Le S
  | Ast.IntOp.LeU -> Le U
  | Ast.IntOp.GeS -> Ge S
  | Ast.IntOp.GeU -> Ge U

(* Float comparison operator -> extracted float relop (the [...f] names). *)
let convert_float_compareop op =
  match op with
  | Ast.FloatOp.Eq -> Eqf
  | Ast.FloatOp.Ne -> Nef
  | Ast.FloatOp.Lt -> Ltf
  | Ast.FloatOp.Gt -> Gtf
  | Ast.FloatOp.Le -> Lef
  | Ast.FloatOp.Ge -> Gef

(* Comparison operator -> [Relop_i] for integer types or [Relop_f] for float
   types, tagged with the operand type. *)
let convert_compareop op =
  match op with
  | I32 o -> Relop_i (T_i32, convert_int_compareop o)
  | I64 o -> Relop_i (T_i64, convert_int_compareop o)
  | F32 o -> Relop_f (T_f32, convert_float_compareop o)
  | F64 o -> Relop_f (T_f64, convert_float_compareop o)

(* Integer unary operator -> extracted integer unop. *)
let convert_int_unop op =
  match op with
  | Ast.IntOp.Clz -> Clz
  | Ast.IntOp.Ctz -> Ctz
  | Ast.IntOp.Popcnt -> Popcnt

(* Float unary operator -> extracted float unop. *)
let convert_float_unop op =
  match op with
  | Ast.FloatOp.Neg -> Neg
  | Ast.FloatOp.Abs -> Abs
  | Ast.FloatOp.Ceil -> Ceil
  | Ast.FloatOp.Floor -> Floor
  | Ast.FloatOp.Trunc -> Trunc
  | Ast.FloatOp.Nearest -> Nearest
  | Ast.FloatOp.Sqrt -> Sqrt

(* Unary operator -> [Unop_i] / [Unop_f] tagged with the operand type. *)
let convert_unop op =
  match op with
  | I32 o -> Unop_i (T_i32, convert_int_unop o)
  | I64 o -> Unop_i (T_i64, convert_int_unop o)
  | F32 o -> Unop_f (T_f32, convert_float_unop o)
  | F64 o -> Unop_f (T_f64, convert_float_unop o)

(* Integer binary operator -> extracted integer binop.  Signed/unsigned
   division, remainder and right shift carry an [S]/[U] argument. *)
let convert_int_binop op =
  match op with
  | Ast.IntOp.Add -> Add
  | Ast.IntOp.Sub -> Sub
  | Ast.IntOp.Mul -> Mul
  | Ast.IntOp.DivS -> Div S
  | Ast.IntOp.DivU -> Div U
  | Ast.IntOp.RemS -> Rem S
  | Ast.IntOp.RemU -> Rem U
  | Ast.IntOp.And -> And
  | Ast.IntOp.Or -> Or
  | Ast.IntOp.Xor -> Xor
  | Ast.IntOp.Shl -> Shl
  | Ast.IntOp.ShrS -> Shr S
  | Ast.IntOp.ShrU -> Shr U
  | Ast.IntOp.Rotl -> Rotl
  | Ast.IntOp.Rotr -> Rotr

(* Float binary operator -> extracted float binop. *)
let convert_float_binop op =
  match op with
  | Ast.FloatOp.Add -> Addf
  | Ast.FloatOp.Sub -> Subf
  | Ast.FloatOp.Mul -> Mulf
  | Ast.FloatOp.Div -> Divf
  | Ast.FloatOp.Min -> Min
  | Ast.FloatOp.Max -> Max
  | Ast.FloatOp.CopySign -> Copysign

(* Binary operator -> [Binop_i] / [Binop_f] tagged with the operand type. *)
let convert_binop op =
  match op with
  | I32 o -> Binop_i (T_i32, convert_int_binop o)
  | I64 o -> Binop_i (T_i64, convert_int_binop o)
  | F32 o -> Binop_f (T_f32, convert_float_binop o)
  | F64 o -> Binop_f (T_f64, convert_float_binop o)

(* The same-width type of the other kind (int <-> float), i.e. the operand
   type of a [reinterpret] that produces the given type. *)
let t_reinterpret t =
  match t with
  | T_f32 -> T_i32
  | T_f64 -> T_i64
  | T_i32 -> T_f32
  | T_i64 -> T_f64

(* Integer conversion operator of type [t1] -> extracted [Cvtop].  The first
   component is [t1] (the type the instruction belongs to) and the third is
   the operand type, e.g. i64.extend_s/i32 becomes
   [Cvtop (T_i64, Convert, T_i32, Some S)]. *)
let convert_int_convertop t1 op =
  let convert_from src sx = Cvtop (t1, Convert, src, sx) in
  match op with
  | Ast.IntOp.ExtendSI32 -> convert_from T_i32 (Some S)
  | Ast.IntOp.ExtendUI32 -> convert_from T_i32 (Some U)
  | Ast.IntOp.WrapI64 -> convert_from T_i64 None
  | Ast.IntOp.TruncSF32 -> convert_from T_f32 (Some S)
  | Ast.IntOp.TruncUF32 -> convert_from T_f32 (Some U)
  | Ast.IntOp.TruncSF64 -> convert_from T_f64 (Some S)
  | Ast.IntOp.TruncUF64 -> convert_from T_f64 (Some U)
  | Ast.IntOp.ReinterpretFloat -> Cvtop (t1, Reinterpret, t_reinterpret t1, None)

(* Float conversion operator of type [t1] -> extracted [Cvtop], laid out as
   in [convert_int_convertop]: result type first, operand type third, then
   the optional signedness. *)
let convert_float_convertop t1 op =
  let convert_from src sx = Cvtop (t1, Convert, src, sx) in
  match op with
  | Ast.FloatOp.ConvertSI32 -> convert_from T_i32 (Some S)
  | Ast.FloatOp.ConvertUI32 -> convert_from T_i32 (Some U)
  | Ast.FloatOp.ConvertSI64 -> convert_from T_i64 (Some S)
  | Ast.FloatOp.ConvertUI64 -> convert_from T_i64 (Some U)
  | Ast.FloatOp.PromoteF32 -> convert_from T_f32 None
  | Ast.FloatOp.DemoteF64 -> convert_from T_f64 None
  | Ast.FloatOp.ReinterpretInt -> Cvtop (t1, Reinterpret, t_reinterpret t1, None)

(* Conversion operator -> [Cvtop], dispatching on the instruction's type. *)
let convert_convertop op =
  match op with
  | I32 o -> convert_int_convertop T_i32 o
  | I64 o -> convert_int_convertop T_i64 o
  | F32 o -> convert_float_convertop T_f32 o
  | F64 o -> convert_float_convertop T_f64 o

(* Convert a global (a ref holding its current value) into a store global.
   NOTE(review): every global is marked [T_mut] regardless of its declared
   mutability, which is not recoverable from the ref alone -- confirm the
   extracted interpreter does not rely on this flag at run time. *)
let convert_glob g =
  let current = convert_value !g in
  Global_ext (T_mut, current, ())

(* Convert a list of value types element-wise, preserving order. *)
let convert_vltype vl_type =
  List.fold_right (fun t acc -> convert_t t :: acc) vl_type []

(* Convert a function type (parameter types, result types) into [Tf]. *)
let convert_ftype (FuncType (params, results)) =
  let params' = convert_vltype params in
  let results' = convert_vltype results in
  Tf (params', results')

(* Build the non-empty label list used by [Br_table]: the labels [ns] in
   order, ending with the default label [n] as the [Base] element. *)
let list_to_ne_list ns n =
  List.fold_right
    (fun label rest -> Conz (var_to_nat label, rest))
    ns
    (Base (var_to_nat n))

(* Packed memory access size -> extracted packed type. *)
let convert_tp size =
  match size with
  | Memory.Mem8 -> Tp_i8
  | Memory.Mem16 -> Tp_i16
  | Memory.Mem32 -> Tp_i32

(* Load extension mode -> signedness: sign-extend is [S], zero-extend [U]. *)
let convert_sx ext =
  match ext with
  | Memory.ZX -> U
  | Memory.SX -> S

(* Optional packed size and extension of a load; [None] means a full-width
   load. *)
let convert_load_tp_sx sz =
  match sz with
  | Some (size, ext) -> Some (convert_tp size, convert_sx ext)
  | None -> None

(* Optional packed size of a store; [None] means a full-width store. *)
let convert_store_tp sz =
  match sz with
  | Some size -> Some (convert_tp size)
  | None -> None

(* Convert a list of AST instructions, preserving order. *)
let rec convert_instrs instrs = List.map convert_instr instrs

(* Convert one reference-interpreter AST instruction into the extracted
   interpreter's basic instruction.  Index operands go through [var_to_nat];
   block-like instructions recurse into their bodies. *)
and convert_instr instr =
  (* Block, loop and if carry only their result types here; the converted
     block type always has an empty parameter list. *)
  let block_type st = Tf ([], convert_vltype st) in
  match instr.it with
  | Ast.Unreachable -> Unreachable
  | Ast.Nop -> Nop
  | Ast.Block (st, body) -> Block (block_type st, convert_instrs body)
  | Ast.Loop (st, body) -> Loop (block_type st, convert_instrs body)
  | Ast.If (st, thn, els) ->
    If (block_type st, convert_instrs thn, convert_instrs els)
  | Ast.Br x -> Br (var_to_nat x)
  | Ast.BrIf x -> Br_if (var_to_nat x)
  | Ast.BrTable (xs, x) -> Br_table (list_to_ne_list xs x)
  | Ast.Return -> Return
  | Ast.Call x -> Call (var_to_nat x)
  | Ast.CallIndirect x -> Call_indirect (var_to_nat x)
  | Ast.Drop -> Drop
  | Ast.Select -> Select
  | Ast.GetLocal x -> Get_local (var_to_nat x)
  | Ast.SetLocal x -> Set_local (var_to_nat x)
  | Ast.TeeLocal x -> Tee_local (var_to_nat x)
  | Ast.GetGlobal x -> Get_global (var_to_nat x)
  | Ast.SetGlobal x -> Set_global (var_to_nat x)
  | Ast.Load op ->
    Load (convert_t op.Ast.ty,
          convert_load_tp_sx op.Ast.sz,
          ocaml_int_to_nat op.Ast.align,
          ocaml_int32_to_nat op.Ast.offset)
  | Ast.Store op ->
    Store (convert_t op.Ast.ty,
           convert_store_tp op.Ast.sz,
           ocaml_int_to_nat op.Ast.align,
           ocaml_int32_to_nat op.Ast.offset)
  | Ast.CurrentMemory -> Current_memory
  | Ast.GrowMemory -> Grow_memory
  | Ast.Const lit -> EConst (convert_value lit.it)
  | Ast.Test op -> convert_testop op
  | Ast.Compare op -> convert_compareop op
  | Ast.Unary op -> convert_unop op
  | Ast.Binary op -> convert_binop op
  | Ast.Convert op -> convert_convertop op

(* [get_table' t n] returns the last [n] entries of table [t] in ascending
   index order, i.e. indices [size t - n] .. [size t - 1].  An entry is
   [Some f] when the slot holds a closure ([Func f]) and [None] otherwise.

   The loop is tail-recursive, so it uses constant stack space.  The previous
   version used one stack frame per table entry, which large tables could
   exhaust.  Entries are read from the highest index downwards and consed
   onto the accumulator, which leaves the result in ascending order.  Any
   exception raised by [Table.load] propagates unchanged. *)
let get_table' t n =
  let size = Table.size t in
  let rec collect taken acc =
    if taken = n then acc
    else
      let taken' = Int32.add taken 1l in
      let entry =
        match Table.load t (Int32.sub size taken') with
        | Func f -> Some f
        | _ -> None
      in
      collect taken' (entry :: acc)
  in
  collect 0l []

(* All entries of table [tbl], in index order (see [get_table']). *)
let get_table tbl = Table.size tbl |> get_table' tbl

(* [memrq s ss] is true when some ref in [ss] currently holds physically the
   same value ([==]) as [s].  The dereferenced contents are compared, not
   the ref cells, so two distinct refs holding the same record count as a
   match. *)
let memrq s ss =
  let target = !s in
  List.exists (fun s' -> target == !s') ss

(* [index_some ss s] is the position of the first ref in [ss] whose contents
   are physically equal to [!s] (same notion of identity as [memrq]).
   Raises [Failure "error in gather procedure"] when there is none. *)
let index_some ss s =
  let target = !s in
  match Lib.List.index_where (fun s' -> !s' == target) ss with
  | Some n -> n
  | None -> failwith "error in gather procedure"

(* Gather, without duplicates (identity as in [memrq]), every instance
   reachable from a starting point, appending newly found instances to the
   end of the accumulator [ss].  An instance is reachable through the native
   closures in its function list and in its tables; host closures reach
   nothing. *)

(* Instances reachable from one closure. *)
let rec collect_instances_cl (ss : (instance ref) list) (cl : closure) : (instance ref) list =
  match cl with
  | HostFunc _ -> ss
  | AstFunc (owner, _) -> collect_instances_instance ss owner

(* Instances reachable from a list of closures, visited left to right. *)
and collect_instances_funcs (ss : (instance ref) list) (funcs : closure list) : (instance ref) list =
  List.fold_left collect_instances_cl ss funcs

(* Instances reachable from the populated slots of one table. *)
and collect_instances_table (ss : (instance ref) list) (table : Table.t) : (instance ref) list =
  List.fold_left
    (fun acc entry ->
       match entry with
       | None -> acc
       | Some cl -> collect_instances_cl acc cl)
    ss (get_table table)

(* Instances reachable from a list of tables, visited left to right. *)
and collect_instances_tables (ss : (instance ref) list) (tables : Table.t list) : (instance ref) list =
  List.fold_left collect_instances_table ss tables

(* Add [s] itself (if new), then everything reachable from its functions and
   then from its tables. *)
and collect_instances_instance (ss : (instance ref) list) (s : instance ref) : (instance ref) list =
  if memrq s ss then ss
  else begin
    let { funcs; tables; _ } = !s in
    let with_self = ss @ [s] in
    let after_funcs = collect_instances_funcs with_self funcs in
    collect_instances_tables after_funcs tables
  end

(* [collect_globals_instance gs s] appends to [gs], in order, each global of
   instance [!s] that [gs] does not already contain.

   Globals are compared by identity of the ref cell ([List.memq]).  The
   earlier test, [memrq], compared the dereferenced *values* physically.
   That merged two distinct globals that happen to hold the same value
   object into one store slot; this can presumably happen when one global is
   initialised from, or set to, a value read from another.  A global shared
   between instances is still the same ref, so it is still deduplicated. *)
let collect_globals_instance (gs : global list) (s : (instance ref)) : (global list) =
  List.fold_left
    (fun gs g -> if List.memq g gs then gs else gs @ [g])
    gs (!s).globals

(* [collect_globals ss] lists every distinct global of the instances [ss] in
   first-occurrence order.  A global's position in this list is the index
   [convert_instance] gives it in the converted store. *)
let collect_globals (ss : (instance ref) list) : (global list) =
  List.fold_left collect_globals_instance [] ss

(* [convert_func ss s f] converts the AST function [f] of instance [s] into
   the components of a native function: the index of [s] in [ss], the
   function type, the local types and the body.
   Raises [Failure] if [s] is not in [ss] or [f]'s type index is invalid. *)
let convert_func (ss : (instance ref) list) (s : instance ref) (f : Ast.func) =
  let { Ast.ftype; Ast.locals; Ast.body } = f.it in
  let n = index_some ss s in
  (n, convert_ftype (ocaml_func_type_of s ftype), convert_vltype locals, convert_instrs body)

(* Convert a closure: native functions refer to their owning instance by its
   index in [ss]; host functions keep the OCaml implementation. *)
let convert_cl (ss : (instance ref) list) = function
  | AstFunc (s, func) ->
    let (inst, ftype, vltype, instrs) = convert_func ss s func in
    Func_native (ocaml_int_to_nat inst, ftype, vltype, instrs)
  | HostFunc (f_type, f) -> Func_host (convert_ftype f_type, f)

(* Convert a list of closures, preserving order. *)
let convert_cls (ss : (instance ref) list) cls = List.map (convert_cl ss) cls

(* Convert table entries (as returned by [get_table]); empty slots stay
   [None]. *)
let convert_table (ss : (instance ref) list) t = List.map (Lib.Option.map (convert_cl ss)) t

(* Earlier version of [convert_instance], superseded by the one below and
   kept for reference only. *)
(*
let convert_instance (ss : (instance ref) list) (s : instance ref) : (I32Wrapper.t, I64Wrapper.t, F32Wrapper.t, F64Wrapper.t, ImplWrapper.memory, ImplWrapper.host_function_t, unit) s_ext =
	let { module_; funcs; tables; memories; globals; exports} = !s in
	let types = List.map (fun (t : (Types.func_type Source.phrase)) -> convert_ftype (t.it)) (module_.it.Ast.types) in
  let (instance_list, cls) = convert_cls [s] funcs in
	let globs = List.map convert_glob globals in
	let (instance_list', tables) = get_tables instance_list tables in
	let tab = Some (InterpreterAux.Arith.zero_nat) in
	let mem = Some (InterpreterAux.Arith.zero_nat) in
	let insts = [Inst_ext (types, cls, globs, tab, mem, ())] in
	S_ext (insts, tables, memories, ()) *)

(* [convert_instance ss (insts, fs, ts, ms, gs) s] converts instance [!s]
   and appends it to a partially built store:
   - its functions are appended to [fs] and referenced by their new indices;
   - its first table (if any) is appended to [ts] and its first memory (if
     any) to [ms]; further tables/memories are ignored (NOTE(review):
     presumably modules have at most one of each -- confirm);
   - an instance without a table (resp. memory) gets [None] as its table
     (resp. memory) index.  It used to get [Some (List.length ts)], an index
     that pointed past the store's tables or aliased the table of an
     instance converted later;
   - its globals are referenced by their position in [gs], found by ref
     identity as in [collect_globals]; [gs] is returned unchanged.
   Raises [Failure] if a global of [!s] is missing from [gs] or if a
   closure cannot be converted. *)
let convert_instance (ss : (instance ref) list) (insts, fs, ts, ms, gs) (s : instance ref) =
  let { module_; funcs; tables; memories; globals; _ } = !s in
  let types =
    List.map
      (fun (t : Types.func_type Source.phrase) -> convert_ftype t.it)
      module_.it.Ast.types
  in
  let fs' = convert_cls ss funcs in
  let f_base = List.length fs in
  let f_inds = List.mapi (fun i _ -> ocaml_int_to_nat (f_base + i)) fs' in
  let (t', t_ind) =
    match tables with
    | table :: _ ->
      ([convert_table ss (get_table table)], Some (ocaml_int_to_nat (List.length ts)))
    | [] -> ([], None)
  in
  let (m', m_ind) =
    match memories with
    | memory :: _ -> ([memory], Some (ocaml_int_to_nat (List.length ms)))
    | [] -> ([], None)
  in
  (* Look globals up by ref identity, consistently with [collect_globals]. *)
  let global_index g =
    match Lib.List.index_where (fun g' -> g' == g) gs with
    | Some n -> n
    | None -> failwith "error in gather procedure"
  in
  let g_inds = List.map (fun g -> ocaml_int_to_nat (global_index g)) globals in
  let inst' = Inst_ext (types, f_inds, t_ind, m_ind, g_inds, ()) in
  (insts @ [inst'], fs @ fs', ts @ t', ms @ m', gs)

(* [create_store s] builds the extracted interpreter's store for instance [s]
   and every instance reachable from it.  Returns [(ss, gs, store)]: [ss]
   lists the instances in store order (position = instance index), and [gs]
   lists the OCaml global refs in store order (position = global index). *)
let create_store (s : (instance ref)) : (instance ref) list * global list * (I32Wrapper.t, I64Wrapper.t, F32Wrapper.t, F64Wrapper.t, ImplWrapperTypes.memory, ImplWrapperTypes.host_function_t, unit) s_ext =
  let ss = collect_instances_instance [] s in
  let gs = collect_globals ss in
  let empty_store = ([], [], [], [], gs) in
  match List.fold_left (convert_instance ss) empty_store ss with
  | (insts, funcs, tables, memories, final_gs) ->
    let store_globals = List.map convert_glob final_gs in
    (ss, gs, S_ext (insts, funcs, tables, memories, store_globals, ()))

(* Turn argument values into constant instructions, in the same order. *)
let convert_values_to_es vs =
  vs
  |> List.map convert_value
  |> List.map (fun c -> Basic (EConst c))

(* Lift basic instructions into administrative instructions via [Basic]. *)
let convert_b_es_to_es b_es =
  List.fold_right (fun instr acc -> Basic instr :: acc) b_es []

(* Initial configuration for invoking closure [cl] on arguments [vs]: an
   empty first component and an instruction sequence that pushes the
   arguments as constants followed by a call of the converted closure. *)
let empty_config (ss : (instance ref) list) (vs : value list) (cl : closure) =
  let call = Callcl (convert_cl ss cl) in
  let args = convert_values_to_es vs in
  ([], args @ [call])
