open SMLofNJ;
open Cont;

(* Raised by the coroutine accessors when handed a value that is not a
   coroutine handle; the string names the operation that failed. *)
exception CoErr of string;

(* CoVal is the single type of every value passed between coroutines.
   A ValCo is a coroutine handle: a mutable cell holding the continuation at
   which the coroutine resumes, paired with a mutable cell holding its parent
   coroutine (ValNull when it has none). *)
datatype CoVal =
    ValInt of int
  | ValCo of CoVal cont ref     (* where the coroutine resumes *)
           * CoVal ref          (* the coroutine that cowait returns to *)
  | ValNull


(* Console output helpers: a string, an int, a CoVal, and a newline. *)
val wrs : string -> unit = print;

fun wri (n : int) = wrs (Int.toString n);

(* Print a CoVal; coroutine handles are shown only by their tag. *)
fun wrv v =
      (case v of
           ValInt n => wri n
         | ValCo _  => wrs "ValCo"
         | ValNull  => wrs "ValNull");

fun nl () = wrs "\n";

(* Scratch cell the test program uses to hold each value returned by callco. *)
val res : CoVal ref = ref ValNull;

(* currco holds the coroutine that is currently running.  It starts out as
   the root coroutine: a ValCo with no parent (ValNull) whose continuation is
   that of this binding.  That continuation is only a placeholder; changeco
   and createco store a fresh continuation into the running coroutine before
   switching away from it. *)
val currco = ref (callcc (fn k => ValCo (ref k, ref ValNull)));


(* Field accessors for a coroutine handle (ValCo).  Each one raises CoErr,
   tagged with its own name, when given anything other than a ValCo. *)

(* The continuation at which the coroutine will resume. *)
fun getcontin co =
      (case co of
           ValCo (kcell, _) => !kcell
         | _ => raise CoErr("getcontin"));

(* Replace the coroutine's resume continuation with k. *)
fun setcontin (co, k) =
      (case co of
           ValCo (kcell, _) => kcell := k
         | _ => raise CoErr("setcontin"));

(* The coroutine's parent (ValNull if it has none). *)
fun getparent co =
      (case co of
           ValCo (_, pcell) => !pcell
         | _ => raise CoErr("getparent"));

(* Replace the coroutine's parent with p. *)
fun setparent (co, p) =
      (case co of
           ValCo (_, pcell) => pcell := p
         | _ => raise CoErr("setparent"));

(* changeco (co, x): suspend the running coroutine and transfer control to
   co, delivering x as the value co resumes with.

   The running coroutine's continuation (the return point of this call) is
   saved first.  When another coroutine later switches back to it with a
   value v, this call returns v.  co may be the running coroutine itself; in
   that case the switch simply returns x.

   Raises CoErr "getcontin" if co is not a coroutine handle (ValCo).  The
   check happens before any state is modified, so a bad co (for example the
   ValNull parent seen by cowait in a parentless coroutine) leaves currco
   and the saved continuations untouched. *)
fun changeco(co, x) =
  ( ignore (getcontin co);   (* validate co before mutating any state *)
    callcc(fn me =>
            ( setcontin(!currco, me);   (* save where we resume *)
              currco := co;
              (* Read co's continuation only after saving ours: when co is
                 the running coroutine this yields me, so the throw just
                 returns x from this call. *)
              throw (getcontin co) x    (* give control to co *)
            )
          )
  );

(* callco (co, x): call co like a subroutine.  The running coroutine becomes
   co's parent, so co's next cowait returns control (and a value) here;
   the result is the value co eventually passes back. *)
fun callco (co, x) =
  let
    val caller = !currco
  in
    setparent (co, caller);
    changeco (co, x)
  end;

(* resumeco (co, x): pass control to co on behalf of our own parent.
   The running coroutine's parent is handed to co and the running coroutine
   is detached (its parent set to ValNull).  co's next cowait therefore
   returns to our former parent rather than to us.  The two parent updates
   happen in this order so that co may be the running coroutine itself. *)
fun resumeco (co, x) =
  let
    val self = !currco
    val inherited = getparent self
  in
    setparent (self, ValNull);
    setparent (co, inherited);
    changeco (co, x)
  end;

(* cowait x: return control, and the value x, to the running coroutine's
   parent, clearing the parent link on the way out.  A second cowait with
   no intervening callco/resumeco therefore has no parent to go to.
   The result is the value later passed back into this coroutine. *)
fun cowait x =
  let
    val self = !currco
    val parent = getparent self
  in
    setparent (self, ValNull);
    changeco (parent, x)
  end;

(* createco f: create a new coroutine whose body applies f to each value it
   is resumed with, and return its handle (a ValCo) without running f yet.

   me is the continuation of this createco call.  The new coroutine starts
   with me as a placeholder continuation and the running coroutine as its
   parent, and the running coroutine's own continuation is also set to me.
   The new coroutine is then made current and enters its loop.  The first
   cowait (!c) passes the handle itself back to the parent by throwing to
   me, so createco returns the handle.  On that switch, changeco stores the
   loop's continuation in the new coroutine, replacing the placeholder.

   Every later switch into the coroutine with a value v makes that cowait
   return v.  The loop then stores f v in c and cowaits with it, handing
   f v back to the coroutine's current parent.  So after the first round,
   c holds the latest result of f, not the handle. *)
fun createco f = callcc(fn me =>
  let val c = ref (ValCo(ref me,          (* placeholder continuation *)
                         ref (!currco)))  (* the parent: the creator *)
  in setcontin(!currco, me);              (* the creator resumes at me *)
     currco := !c;                        (* Make it the current coroutine *)

     while true do c := f(cowait (!c));   (* The main coroutine loop *)
     !c   (* never reached; only gives the callcc body type CoVal *)
  end);

(* Test program *)

(* main: create four coroutines from a_fn .. d_fn and drive them from the
   root coroutine.  The tests exercise callco, the repeated
   c := f(cowait c) loop inside createco, and resumeco, including a
   coroutine resuming itself.  Each *_fn prints what it received and
   returns its argument plus a fixed offset (10, 20, 30, 40).  A non-ValInt
   argument is reported and echoed back.  Returns ValInt 0 on completion. *)
fun main() =
  let val a_co = ref ValNull
      and b_co = ref ValNull
      and c_co = ref ValNull
      and d_co = ref ValNull;

      fun a_fn (ValInt x) =
            ( wrs("a_co: entered with value "); wri x; nl();
              wrs("a_co: returning "); wri(x+10); nl();
              ValInt(x+10) )
      |   a_fn x = ( wrs "a_fn: Bad argument "; wrv x; nl(); x );

      (* b_fn calls a_co as a subroutine before returning. *)
      fun b_fn (ValInt x) =
        let val res = ref ValNull
        in wrs("b_co: entered with value "); wri(x); nl();
           wrs("b_co: calling callco(a_co, 2000)\n");
           res := callco(!a_co, ValInt 2000);
           wrs("b_co: callco(a_co, 2000) => "); wrv(!res); nl();
           wrs("b_co: returning "); wri(x+20); nl();
           ValInt(x+20)
        end
      |   b_fn x = ( wrs "b_fn: Bad argument "; wrv x; nl(); x );

      (* c_fn hands its parent over to a_co, so a_co returns to root. *)
      fun c_fn (ValInt x) =
        let val res = ref ValNull
        in wrs("c_co: entered with value "); wri x; nl();
           wrs("c_co: calling resumeco(a_co, 3000)\n");
           res := resumeco(!a_co,ValInt 3000);
           wrs("c_co: resumeco(a_co, 3000) => "); wrv(!res); nl();
           wrs("c_co: returning "); wri(x+30); nl();
           ValInt(x+30)
        end
      |   c_fn x = ( wrs "c_fn: Bad argument "; wrv x; nl(); x );

      (* d_fn resumes itself, which should simply return the value sent. *)
      fun d_fn (ValInt x) =
        let val res = ref ValNull
        in wrs("d_co: entered with value "); wri x; nl();
           wrs("d_co: calling resumeco(d_co, 4000)\n");
           res := resumeco(!d_co,ValInt 4000);
           wrs("d_co: resumeco(d_co, 4000) => "); wrv(!res); nl();
           wrs("d_co: returning "); wri(x+40); nl();
           ValInt(x+40)
        end
      |   d_fn x = ( wrs "d_fn: Bad argument "; wrv x; nl(); x );

  in
     a_co := createco a_fn;
     b_co := createco b_fn;
     c_co := createco c_fn;
     d_co := createco d_fn;

     wrs("\nTest: root -> a -> root\n");
     wrs("root: calling callco(a_co,  100)\n");
     res := callco(!a_co, ValInt 100);
     wrs("root: callco(a_co,  100) => "); wrv(!res); nl();

     wrs("\nTest: root -> a -> root, ie check: c:=fn(cowait(c)) REPEAT\n");
     wrs("root: calling callco(a_co,  200)\n");
     res := callco(!a_co, ValInt 200);
     wrs("root: callco(a_co,  200) => "); wrv(!res); nl();

     wrs "\nTest: root -> b -> a -> b -> root\n";
     wrs "root: calling callco(b_co,  300)\n";
     res := callco(!b_co, ValInt 300);
     wrs "root: callco(b_co,  300) => "; wrv(!res); nl();

     wrs "\nTest: root -> b -> a -> b -> root  again\n";
     wrs "root: calling callco(b_co,  400)\n";
     res := callco(!b_co, ValInt 400);
     wrs "root: callco(b_co,  400) => "; wrv(!res); nl();

     wrs "\nTest: root -> c -> a -> root, ie check resumeco in c_co\n";
     wrs "root: calling callco(c_co,  500)\n";
     res := callco(!c_co, ValInt 500);
     wrs "root: callco(c_co,  500) => "; wrv(!res); nl();

     wrs("\nTest: root -> c -> root\n");
     wrs("root: calling callco(c_co,  600)\n");
     res := callco(!c_co, ValInt 600);
     wrs "root: callco(c_co,  600) => "; wrv(!res); nl();

     wrs("\nTest: root -> d -> d -> root, ie can resumeco call itself\n");
     wrs("root: calling callco(d_co,  700)\n");
     res := callco(!d_co, ValInt 700);
     wrs "root: callco(d_co,  700) => "; wrv(!res); nl();

     wrs "\nEnd of test\n";
     ValInt 0
  end;

(* Run the tests, reporting any coroutine misuse instead of aborting. *)
main() handle CoErr x => (wrs "CoErr: "; wrs x; nl(); ValInt 999);

(* This program should generate the following output:

Test: root -> a -> root
root: calling callco(a_co,  100)
a_co: entered with value 100
a_co: returning 110
root: callco(a_co,  100) => 110

Test: root -> a -> root, ie check: c:=fn(cowait(c)) REPEAT
root: calling callco(a_co,  200)
a_co: entered with value 200
a_co: returning 210
root: callco(a_co,  200) => 210

Test: root -> b -> a -> b -> root
root: calling callco(b_co,  300)
b_co: entered with value 300
b_co: calling callco(a_co, 2000)
a_co: entered with value 2000
a_co: returning 2010
b_co: callco(a_co, 2000) => 2010
b_co: returning 320
root: callco(b_co,  300) => 320

Test: root -> b -> a -> b -> root  again
root: calling callco(b_co,  400)
b_co: entered with value 400
b_co: calling callco(a_co, 2000)
a_co: entered with value 2000
a_co: returning 2010
b_co: callco(a_co, 2000) => 2010
b_co: returning 420
root: callco(b_co,  400) => 420

Test: root -> c -> a -> root, ie check resumeco in c_co
root: calling callco(c_co,  500)
c_co: entered with value 500
c_co: calling resumeco(a_co, 3000)
a_co: entered with value 3000
a_co: returning 3010
root: callco(c_co,  500) => 3010

Test: root -> c -> root
root: calling callco(c_co,  600)
c_co: resumeco(a_co, 3000) => 600
c_co: returning 530
root: callco(c_co,  600) => 530

Test: root -> d -> d -> root, ie can resumeco call itself
root: calling callco(d_co,  700)
d_co: entered with value 700
d_co: calling resumeco(d_co, 4000)
d_co: resumeco(d_co, 4000) => 4000
d_co: returning 740
root: callco(d_co,  700) => 740

End of test
*)
