(* NOTE(review): `junk` is not referenced anywhere in the visible source;
   presumably a placeholder binding -- confirm before removing. *)
val junk = 1;

(* 
This is a translation of the BCPL Cobench benchmark
into New Jersey SML (using continuations).

Translated by Martin Richards (c) 31 March 2004
*)

open SMLofNJ;
open StringCvt;
open CommandLine;
open Cont;

(* Raised by the CoVal accessor functions in this file when they are given
   a value built with the wrong constructor; the string carried is the name
   of the accessor that failed. *)
exception CoErr of string;

(* CoVal is the type of value passed from one coroutine to another.
   The same type also represents a coroutine itself (ValCo) and a
   mutable cell shared between coroutines (ValRef). *)
datatype CoVal = ValInt of int               (* a plain integer *)
               | ValCo  of CoVal cont ref *  (* The Continuation: where to resume *)
                           CoVal ref         (* The Parent: a CoVal, ValNull when none *)
               | ValRef of CoVal ref         (* a reference cell holding a CoVal *)
               | ValNull                     (* the "no value" marker *)


(* wrs text: write the string text to standard output. *)
fun wrs (text : string) : unit = print text;

(* wri i: write the integer i to standard output in Int.toString format. *)
fun wri (i : int) : unit =
  let val digits = Int.toString i
  in print digits
  end;

(* wrv v: write a CoVal. Integers are printed by value; every other
   constructor is printed as its constructor name only. *)
fun wrv v =
  case v of
    ValInt i => wri i
  | ValCo _  => wrs "ValCo"
  | ValRef _ => wrs "ValRef"
  | ValNull  => wrs "ValNull";

(* nl (): write a newline character. *)
fun nl () : unit = wrs "\n";

(* NOTE(review): this top-level cell is not read or written anywhere in the
   visible source (the `res` names inside the accessor patterns are local
   pattern variables that shadow it) -- presumably a leftover; confirm. *)
val res = ref ValNull;

(* currco is the current coroutine.
   It is initialised to a ValCo describing the root coroutine: callcc
   captures the continuation of this initialiser and immediately throws a
   ValCo holding that continuation (and a ValNull parent) back to it, so
   the value bound inside the ref is that ValCo. *)
val currco = ref (callcc(fn me => throw me (ValCo(ref me, ref ValNull))));


(* getcontin co: the continuation currently stored in coroutine co.
   Raises CoErr "getcontin" if co is not a ValCo. *)
fun getcontin co =
  case co of
    ValCo(ref con, _) => con
  | _                 => raise CoErr("getcontin");

(* setcontin (co, x): store continuation x in coroutine co.
   Raises CoErr "setcontin" if co is not a ValCo. *)
fun setcontin (co, x) =
  case co of
    ValCo(slot, _) => slot := x
  | _              => raise CoErr("setcontin");

(* getparent co: the parent currently recorded in coroutine co.
   Raises CoErr "getparent" if co is not a ValCo. *)
fun getparent co =
  case co of
    ValCo(_, ref parent) => parent
  | _                    => raise CoErr("getparent");

(* setparent (co, x): record x as the parent of coroutine co.
   Raises CoErr "setparent" if co is not a ValCo. *)
fun setparent (co, x) =
  case co of
    ValCo(_, slot) => slot := x
  | _              => raise CoErr("setparent");

(* getint v: the integer carried by a ValInt.
   Raises CoErr "getint" for any other constructor. *)
fun getint v =
  case v of
    ValInt i => i
  | _        => raise CoErr("getint");

(* getref v: the reference cell carried by a ValRef.
   Raises CoErr "getref" for any other constructor. *)
fun getref v =
  case v of
    ValRef cell => cell
  | _           => raise CoErr("getref");

(* changeco(co, x): transfer control from the current coroutine to co,
   throwing x to the continuation stored in co.

   The continuation of this call is first saved in the current coroutine's
   record, so a later transfer back to it makes this call return with the
   value thrown at that time.  Statement order matters: co's continuation
   is fetched only after the save, so when co is the current coroutine the
   continuation thrown to is the one just saved. *)
fun changeco(co, x) =
  callcc(fn me =>
          ( setcontin(!currco, me); (* save our continuation *)
            currco := co;           (* co becomes the current coroutine *)
            throw (getcontin co) x;    (* give control to co *)
            x  (* follows the throw, which transfers control away; gives the body a result *)
          )
        );

(* callco(co, x): make the current coroutine the parent of co, then
   transfer control to co with value x.  The result is whatever value is
   later passed back when control returns to the caller. *)
fun callco(co, x) =
  let val caller = !currco
  in setparent(co, caller);
     changeco(co, x)
  end;

(* resumeco(co, x): hand the current coroutine's parent over to co and
   transfer control to co with value x.  The current coroutine is left
   with a ValNull parent. *)
fun resumeco(co, x) =
  let val self      = !currco
      val inherited = getparent self
  in setparent(self, ValNull);
     setparent(co, inherited);
     changeco(co, x)
  end;

(* cowait x: detach the current coroutine from its parent (its parent
   field becomes ValNull) and transfer control to that parent with
   value x. *)
fun cowait x =
  let val self   = !currco
      val parent = getparent self
  in setparent(self, ValNull);
     changeco(parent, x)
  end;

(* createco f: create a coroutine whose body is f.

   A new ValCo is built with the current coroutine as its parent and is
   made current.  Its first action is cowait (!c), which passes the new
   ValCo back to that parent; the parent's saved continuation is `me`, the
   continuation of this createco call, so createco appears to return the
   new coroutine to its creator.

   Each later activation applies f to the value handed in, stores f's
   result in c, and cowaits again with that result. *)
fun createco f = callcc(fn me =>
  let val c = ref (ValCo(ref me,          (* the continuation *)
                         ref (!currco)))  (* the parent *)
  in setcontin(!currco, me);              (* creator resumes at createco's return *)
     currco := !c;                        (* Make it the current coroutine *)

     while true do c := f(cowait (!c));   (* The main coroutine loop *)
     !c   (* just to get the result type right *)
  end);


(* Benchmark program *)

(* Run-time parameters of the benchmark.  The initial values are the same
   defaults that the main program assigns before reading the command line. *)
val tracing = ref false;   (* when true, the coroutines print trace lines *)
val k = ref 10000;         (* count passed to the source coroutine *)
val n = ref 500;           (* number of copy coroutines to create *)

(* coread channel: read one value from a channel (a ValRef cell).

   If a coroutine is already parked in the channel word, clear the word
   and resume that coroutine, passing it the current coroutine.
   Otherwise park the current coroutine in the channel word and wait for
   a value to be delivered.

   Raises CoErr "getref" if channel is not a ValRef. *)
fun coread channel =
  let val chan = getref channel
      (* Read the channel word before clearing it: once cleared, !chan is
         ValNull and can no longer identify the waiting coroutine. *)
      val co = !chan
  in case co of
       ValCo(_, _) =>
         ( chan := ValNull;        (* Clear the channel word *)
           resumeco(co, !currco)   (* Wake the waiting coroutine *)
         )
     | _ =>
         ( chan := !currco;        (* Set channel word to this coroutine *)
           cowait ValNull          (* Wait for value from cowrite *)
         )
  end;


(* cowrite(channel, value): send value through a channel (a ValRef cell).

   If a coroutine is parked in the channel word, clear the word and call
   that coroutine with value.  Otherwise park the current coroutine in the
   channel word, wait, and call whatever coroutine the wait yields.

   Raises CoErr "getref" if channel is not a ValRef. *)
fun cowrite(channel, value) =
  let val slot = getref channel
  in case !slot of
       reader as ValCo _ =>
         ( slot := ValNull;            (* channel word is free again *)
           callco(reader, value)
         )
     | _ =>
         ( slot := !currco;            (* announce ourselves as waiting *)
           let val reader = cowait ValNull
           in callco(reader, value)
           end
         )
  end;



(* sourceiter(i, j, channel): write the integers i, i+1, ..., j-1 to
   channel in order, tracing each one when tracing is enabled.
   Does nothing when i >= j. *)
fun sourceiter(i, j, channel) =
  if i >= j then ()
  else
    let val _ = if !tracing
                then ( wrs "sourceco: sending number "; wri i; nl() )
                else ()
        val _ = cowrite(channel, ValInt i)
    in sourceiter(i+1, j, channel)
    end;

(* sourcefn nextco: body of the source coroutine.

   Waits to be handed the count k' (a ValInt), gives nextco a fresh
   channel, writes a run of integers to that channel followed by a
   terminating 0, and returns ValNull.

   The count used is the one actually received through cowait, not the
   global k.

   Raises CoErr "getint" if the value handed in is not a ValInt. *)
fun sourcefn nextco =
  let val k'      = getint(cowait ValNull)   (* count supplied by the caller *)
      val channel = ValRef(ref ValNull)      (* empty channel to nextco *)
      fun trace msg = if !tracing then wrs msg else ()
  in
     if !tracing
     then ( wrs "sourceco: started with k = "; wri k'; nl() )
     else ();
     callco(nextco, channel);
     (* NOTE(review): sourceiter's upper bound is exclusive, so this sends
        1 .. k'-1; together with the final 0 that is k' values in all --
        confirm this is the intended count. *)
     sourceiter(1, k', channel);
     trace "sourceco: sending number 0\n";
     cowrite(channel, ValInt 0);             (* send 0 to next co *)
     trace "sourceco: dying\n";
     ValNull
  end;

(* copyfn nextco: body of a copy coroutine.

   Creates its output channel, waits to be handed its input channel,
   passes the output channel on to nextco, and then relays every value
   read from the input channel to the output channel.  The relay stops
   after a 0 has been forwarded.  Returns ValNull. *)
fun copyfn nextco =
  let val outchannel = ValRef(ref ValNull)
      val inchannel  = cowait ValNull

      (* Forward one value, then continue unless it was the terminator. *)
      fun relay () =
        let val v = getint(coread inchannel)
        in if !tracing
           then ( wrs "copyco:   copying number "; wri v; nl() )
           else ();
           cowrite(outchannel, ValInt v);
           if v = 0 then () else relay ()
        end
  in callco(nextco, outchannel);
     relay ();
     if !tracing
     then wrs "copyco:   dying\n"
     else ();
     ValNull
  end;

(* sinkfn in_channel: body of the sink coroutine.

   Reads values from in_channel until a 0 arrives, tracing each value
   when tracing is enabled, then returns ValNull. *)
fun sinkfn in_channel =
  let (* Consume one value, then continue unless it was the terminator. *)
      fun drain () =
        let val v = getint(coread in_channel)
        in if !tracing
           then ( wrs "sinkco:   received value "; wri v; nl() )
           else ();
           if v = 0 then () else drain ()
        end
  in drain ();
     if !tracing
     then wrs "sinkco:   dying\n"
     else ();
     ValNull
  end;

(* mksinkco (): create the sink coroutine and return it. *)
fun mksinkco () : CoVal = createco sinkfn

(* mkcopycos(n, cptr): build a chain of n copy coroutines in front of
   cptr and return the head of the chain.

   Each new copy coroutine is called once with the coroutine that follows
   it.  When n is zero or negative no coroutine is created and cptr itself
   is returned; the guard on negative n keeps the recursion from running
   without bound. *)
fun mkcopycos(n, cptr) =
  if n <= 0 then cptr
  else
    let val co = createco copyfn
    in callco(co, cptr);         (* tell the new coroutine its successor *)
       mkcopycos(n-1, co)
    end;

(* mksourceco cptr: create the source coroutine, call it once with the
   coroutine cptr that follows it, and return the source coroutine. *)
fun mksourceco cptr =
  let val co = createco sourcefn
      val _  = callco(co, cptr)
  in co
  end;

(* cobench (): build the pipeline  source -> n copy coroutines -> sink,
   start the source with the count k, and report completion. *)
fun cobench() =
  let val copies    = !n
      val sink_co   = mksinkco()
      val head_co   = mkcopycos(copies, sink_co)
      val source_co = mksourceco head_co
  in
     if !tracing
     then wrs "All coroutines created\n\n"
     else ();
     callco(source_co, ValInt (!k)); (* hand the count to the source coroutine *)
     wrs "\nCobench done\n"
  end;

(* currtime (): the current time as reported by Time.now. *)
fun currtime () : Time.time = Time.now ();

(* atoi s: the integer parsed from s by Int.fromString, or 0 when
   Int.fromString finds no integer in s. *)
fun atoi s = getOpt (Int.fromString s, 0);

(* processargs args: apply command-line options to the global settings.

     -t        turn tracing on
     -n <num>  set n from <num> (via atoi)
     -k <num>  set k from <num> (via atoi)

   NOTE(review): processing stops silently at the first argument that is
   not one of these options (including -n or -k with no value following);
   presumably intentional -- confirm. *)
fun processargs args =
  case args of
    "-t" :: rest        => (tracing := true; processargs rest)
  | "-n" :: arg :: rest => (n := atoi arg;  processargs rest)
  | "-k" :: arg :: rest => (k := atoi arg;  processargs rest)
  | _                   => ();

(* rdargs (): read the command-line arguments and process all but the
   first of them as options.
   NOTE(review): the first argument is skipped unconditionally --
   presumably it is the name of the source file being run; confirm. *)
fun rdargs() =
  case arguments() of
    _ :: options => processargs options
  | []           => ();

(* Main program: set the default parameters, apply any command-line
   options, announce the run, then time cobench().  A CoErr raised during
   the timed section is reported instead of propagating. *)
( tracing := false;
  k := 10000;
  n := 500;

  rdargs();

  wrs "\nCobench sending "; wri (!k);
  wrs " numbers via "; wri (!n);
  wrs " copy coroutines\n";

  let val started = currtime()
  in  cobench();
      wrs "\nTime taken is ";
      let val elapsed = Time.-(currtime(), started)
      in wrs (Time.toString elapsed)
      end;
      nl(); nl()
  end
  handle CoErr why => (wrs "CoErr: "; wrs why; wrs "\nCobench failed\n\n")
)