(* MODIFIED by Conrad Watt *)

open Ast
open Source
open Types


(* Errors *)

(* Error machinery for validation. [Invalid] is re-exported as an exception
 * alias of [Invalid.Error]; [Error.Make] is defined elsewhere. *)
module Invalid = Error.Make ()
exception Invalid = Invalid.Error

(* [error at msg] reports a validation error at region [at]. It is
 * [Invalid.error]; presumably it raises [Invalid] - see Error.Make. *)
let error = Invalid.error

(* [require cond at msg] does nothing when [cond] holds; otherwise it
 * reports [msg] at [at] through [error]. *)
let require cond at msg =
  match cond with
  | true -> ()
  | false -> error at msg


(* Context *)

(* Typing context threaded through validation. Each list is indexed by the
 * int32 carried in a [var] (see [lookup]). *)
type context =
{
  module_ : module_;            (* the module being validated *)
  types : func_type list;       (* the module's declared function types *)
  funcs : func_type list;       (* types of imported, then defined, functions *)
  tables : table_type list;     (* imported, then defined, tables *)
  memories : memory_type list;  (* imported, then defined, memories *)
  globals : global_type list;   (* visible globals (stage-dependent) *)
  locals : value_type list;     (* in a function: parameters, then locals *)
  results : value_type list;    (* in a function: its result types *)
  labels : stack_type list;     (* label result types; check_func seeds [out] *)
}

(* [context m] is the empty typing context for module [m]: every index
 * space and every per-function field starts out empty. *)
let context m =
  {
    module_ = m;
    types = [];
    funcs = [];
    tables = [];
    memories = [];
    globals = [];
    locals = [];
    results = [];
    labels = [];
  }

(* [lookup category list x] returns the element of [list] at index [x.it].
 * When [Lib.List32.nth] fails with [Failure], it reports
 * "unknown <category> <index>" at [x.at]. *)
let lookup category list x =
  match Lib.List32.nth list x.it with
  | entry -> entry
  | exception Failure _ ->
    let index = Int32.to_string x.it in
    error x.at ("unknown " ^ category ^ " " ^ index)

(* Index-space accessors: each resolves a [var] in the corresponding
 * context list, reporting "unknown <kind> <index>" when out of range. *)
let type_ (c : context) = lookup "type" c.types
let func (c : context) = lookup "function" c.funcs
let global (c : context) = lookup "global" c.globals
let table (c : context) = lookup "table" c.tables
let memory (c : context) = lookup "memory" c.memories

(* [check_arity n at] rejects a result arity [n] greater than one,
 * reporting the failure at region [at]. *)
let check_arity n at =
  let at_most_one = n <= 1 in
  require at_most_one at
    "invalid result arity, larger than 1 is not (yet) allowed"

(* [check_block c es ts at] type-checks the instruction sequence [es]
 * against the block type [] -> [ts] using the extracted checker
 * [Checker.Wasm_Checker_Printing.typing].
 *
 * The fields of [c] are converted with [Ast_convert] into the checker's
 * [T_context_ext] record, and [results] is always passed as [Some _].
 * Tables and memories are forwarded only by presence: a non-empty list
 * becomes [Some zero_nat] and an empty list becomes [None].
 *
 * When the checker returns false, reports "type mismatch" at [at] via
 * [require]. *)
let check_block (c : context) (es : instr list) (ts : stack_type) at =
  (* Only the existence of a table/memory is passed on; the size argument
   * is fixed at zero. NOTE(review): presumably the checker only tests for
   * presence here - confirm against the extracted typing rules. *)
  let present = function
    | [] -> None
    | _ :: _ -> Some InterpreterAux.Arith.zero_nat
  in
  (* Field order must match the extracted [T_context_ext] constructor. *)
  let cont =
    InterpreterAux.Wasm.T_context_ext
      ( List.map Ast_convert.convert_ftype c.types,
        List.map Ast_convert.convert_ftype c.funcs,
        List.map Ast_convert.convert_tg c.globals,
        present c.tables,
        present c.memories,
        Ast_convert.convert_vltype c.locals,
        List.map Ast_convert.convert_vltype c.labels,
        Some (Ast_convert.convert_vltype c.results),
        () )
  in
  let block_type =
    InterpreterAux.Wasm.Tf ([], Ast_convert.convert_vltype ts)
  in
  let well_typed =
    Checker.Wasm_Checker_Printing.typing cont
      (Ast_convert.convert_instrs es) block_type
  in
  require well_typed at "type mismatch"


(* Functions & Constants *)

(*
 * Conventions:
 *   c : context
 *   m : module_
 *   f : func
 *   e : instr
 *   v : value
 *   t : value_type
 *   s : func_type
 *   x : variable
 *)

(* [check_type t] checks that the declared function type [t] has at most
 * one result; the parameter list is not inspected. *)
let check_type (t : type_) =
  match t.it with
  | FuncType (_, results) -> check_arity (List.length results) t.at

(* [check_func c f] validates function [f]. Its body is checked as a block
 * producing the function's results, in a context where:
 *   - locals are the parameters followed by the declared locals;
 *   - the return type is the results;
 *   - the single label carries the results. *)
let check_func (c : context) (f : func) =
  let {ftype = type_idx; locals = declared; body = code} = f.it in
  let FuncType (params, out) = type_ c type_idx in
  check_block
    {c with locals = params @ declared; results = out; labels = [out]}
    code out f.at


(* [is_const c e] holds when [e] may appear in a constant expression.
 * Allowed forms:
 *   - a [Const];
 *   - a [GetGlobal] of an immutable global.
 * Looking up the global reports an error if its index is unknown. *)
let is_const (c : context) (e : instr) =
  let immutable_global x =
    let GlobalType (_, mut) = global c x in
    mut = Immutable
  in
  match e.it with
  | GetGlobal x -> immutable_global x
  | Const _ -> true
  | _ -> false

(* [check_const c const t] requires every instruction of [const] to be a
 * constant form, then type-checks it as a block producing a single [t]. *)
let check_const (c : context) (const : const) (t : value_type) =
  let instrs = const.it and region = const.at in
  require (List.for_all (is_const c) instrs) region
    "constant expression required";
  check_block c instrs [t] region


(* Tables, Memories, & Globals *)

(* [check_table_type t at] checks that a table's minimum size does not
 * exceed its maximum (unsigned comparison), if a maximum is given. *)
let check_table_type (t : table_type) at =
  match t with
  | TableType ({max = None; _}, _) -> ()
  | TableType ({min = lo; max = Some hi}, _) ->
    require (I32.le_u lo hi) at
      "table size minimum must not be greater than maximum"

(* [check_table c tab] validates a module-defined table's declared type,
 * reporting errors at the table's region. [c] is not consulted. *)
let check_table (c : context) (tab : table) =
  check_table_type tab.it.ttype tab.at

(* [check_memory_type t at] validates memory limits with unsigned
 * comparisons, in this order:
 *   1. the minimum is at most 65536 pages;
 *   2. the maximum, if any, is at most 65536 pages;
 *   3. the minimum does not exceed the maximum.
 * Errors are reported at [at]. *)
let check_memory_type (t : memory_type) at =
  let page_cap = 65536l in
  let too_big = "memory size must be at most 65536 pages (4GiB)" in
  let MemoryType {min = lo; max = hi_opt} = t in
  require (I32.le_u lo page_cap) at too_big;
  match hi_opt with
  | None -> ()
  | Some hi ->
    require (I32.le_u hi page_cap) at too_big;
    require (I32.le_u lo hi) at
      "memory size minimum must not be greater than maximum"

(* [check_memory c mem] validates a module-defined memory's declared type,
 * reporting errors at the memory's region. [c] is not consulted. *)
let check_memory (c : context) (mem : memory) =
  check_memory_type mem.it.mtype mem.at

(* [check_elem c seg] validates an element segment:
 *   - its offset must be a constant i32 expression;
 *   - its table index must exist;
 *   - every function index in [init] must exist.
 * Function indices are checked left to right, so the first unknown one is
 * the one reported. *)
let check_elem (c : context) (seg : table_segment) =
  let {index; offset; init} = seg.it in
  check_const c offset I32Type;
  ignore (table c index);
  (* Only existence matters: iterate rather than build a discarded list. *)
  List.iter (fun x -> ignore (func c x)) init

(* [check_data c seg] validates a data segment. Its offset must be a
 * constant i32 expression, and its memory index must exist. *)
let check_data (c : context) (seg : memory_segment) =
  let {index = mem_idx; offset; _} = seg.it in
  check_const c offset I32Type;
  memory c mem_idx |> ignore

(* [check_global c glob] checks that the global's initializer is a constant
 * expression of the global's value type; mutability is not inspected. *)
let check_global (c : context) (glob : global) =
  match glob.it with
  | {gtype = GlobalType (t, _); value} -> check_const c value t


(* Modules *)

(* [check_start c start] checks that the start function, if any, exists and
 * has type [] -> []. *)
let check_start (c : context) (start : var option) =
  match start with
  | None -> ()
  | Some x ->
    let nullary = FuncType ([], []) in
    require (func c x = nullary) x.at
      "start function must not have parameters or results"

(* [check_import im c] validates import [im] and returns [c] with the
 * imported entity prepended to the matching list. Requirements by kind:
 *   - function imports must name a known type;
 *   - table and memory imports must have valid limits;
 *   - global imports must be immutable. *)
let check_import (im : import) (c : context) : context =
  let {idesc; _} = im.it in
  let at = idesc.at in
  match idesc.it with
  | FuncImport x ->
    let ft = type_ c x in
    {c with funcs = ft :: c.funcs}
  | TableImport tt ->
    check_table_type tt at;
    {c with tables = tt :: c.tables}
  | MemoryImport mt ->
    check_memory_type mt at;
    {c with memories = mt :: c.memories}
  | GlobalImport (GlobalType (_, mut) as gt) ->
    require (mut = Immutable) at "mutable globals cannot be imported (yet)";
    {c with globals = gt :: c.globals}

(* Set of export names, used by [check_export] to reject duplicates.
 * Ordering is the polymorphic [compare] on [Ast.name]. *)
module NameSet = Set.Make(struct type t = Ast.name let compare = compare end)

(* [check_export c set ex] validates export [ex] against [c] and returns
 * [set] extended with the export's name. It checks that:
 *   - the exported entity exists;
 *   - an exported global is immutable;
 *   - the name does not already occur in [set]. *)
let check_export (c : context) (set : NameSet.t) (ex : export) : NameSet.t =
  let {name; edesc} = ex.it in
  let () =
    match edesc.it with
    | FuncExport x -> ignore (func c x)
    | TableExport x -> ignore (table c x)
    | MemoryExport x -> ignore (memory c x)
    | GlobalExport x ->
      let GlobalType (_, mut) = global c x in
      require (mut = Immutable) edesc.at
        "mutable globals cannot be exported (yet)"
  in
  let already_exported = NameSet.mem name set in
  require (not already_exported) ex.at "duplicate export name";
  NameSet.add name set

(* [check_module m] validates module [m]; the first violation found is
 * reported via [require]/[error]. Contexts are built in three stages:
 *   c0 - declared types plus all imports (fold_right combined with the
 *        prepending in check_import leaves imports in declaration order);
 *   c1 - c0 plus module-defined functions, tables and memories;
 *   c  - c1 plus module-defined globals.
 * Global initializers and element/data offsets are checked against c1, so
 * they can only refer to imported globals. Function bodies, the start
 * function and exports see the full context c. The order of the checks
 * below determines which error is reported when several apply. *)
let check_module (m : module_) =
  let
    { types; imports; tables; memories; globals; funcs; start; elems; data;
      exports } = m.it
  in
  (* Imports are validated here, seeing the declared types. *)
  let c0 =
    List.fold_right check_import imports
      {(context m) with types = List.map (fun ty -> ty.it) types}
  in
  (* Defined entities come after imported ones in each index space. *)
  let c1 =
    { c0 with
      funcs = c0.funcs @ List.map (fun f -> type_ c0 f.it.ftype) funcs;
      tables = c0.tables @ List.map (fun tab -> tab.it.ttype) tables;
      memories = c0.memories @ List.map (fun mem -> mem.it.mtype) memories;
    }
  in
  let c =
    { c1 with globals = c1.globals @ List.map (fun g -> g.it.gtype) globals }
  in
  List.iter check_type types;
  List.iter (check_global c1) globals;
  List.iter (check_table c1) tables;
  List.iter (check_memory c1) memories;
  List.iter (check_elem c1) elems;
  List.iter (check_data c1) data;
  List.iter (check_func c) funcs;
  check_start c start;
  ignore (List.fold_left (check_export c) NameSet.empty exports);
  (* These counts include imported tables/memories. *)
  require (List.length c.tables <= 1) m.at
    "multiple tables are not allowed (yet)";
  require (List.length c.memories <= 1) m.at
    "multiple memories are not allowed (yet)"
