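(*---------------------------------------------------------------------------*
 * General socket utilities for Moscow ML: starting and locating a Core      *
 * Proof Engine (CPE) server, a serial server loop, and a simple             *
 * length-prefixed string protocol for requests and replies.                 *
 *---------------------------------------------------------------------------*)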

structure SocketUtils :> SocketUtils =
struct

open Utils;


(*---------------------------------------------------------------------------*
 * Connect to a server in the Internet domain at ipno/port (both passed      *
 * straight to Socket.inetAddr) and return the connected socket.             *
 *                                                                           *
 * A fresh socket is created on every call: after a failed connect in the    *
 * Internet domain the socket appears to be unusable for a later connect, so *
 * a retry must call this function again rather than reuse a socket. If the  *
 * connect fails, the socket is closed and the exception is re-raised.       *
 *---------------------------------------------------------------------------*)
fun findInetServer ipno port
 :(Socket.pf_inet, Socket.active Socket.stream)Socket.socket
  = let (* Build the address first: if it is rejected, no socket exists yet. *)
        val addr = Socket.inetAddr ipno port
        val sock = Socket.inetStream ()
    in
       Socket.connect(sock, addr) handle e => (Socket.close sock; raise e);
       sock
    end;

(*---------------------------------------------------------------------------*
 * Connect to a server in the Unix (file) domain listening on the socket     *
 * file `file`, and return the connected socket. As above, a fresh socket is *
 * created per call and is closed again if the connect fails.                *
 *---------------------------------------------------------------------------*)
fun findUnixServer file
 :(Socket.pf_file, Socket.active Socket.stream)Socket.socket
  = let (* Build the address first: if it is rejected, no socket exists yet. *)
        val addr = Socket.fileAddr file
        val sock = Socket.fileStream ()
    in
       Socket.connect(sock, addr) handle e => (Socket.close sock; raise e);
       sock
    end;

(*---------------------------------------------------------------------------*
 * Start a CPE on the current machine (as a background shell command) and    *
 * return a socket for communicating with it.                                *
 *                                                                           *
 * createInetServer returns the socket; createUnixServer returns the socket  *
 * paired with the temporary file name used as the Unix-domain address.      *
 * If the shell command cannot be run, Fail "Unable to start Core Proof      *
 * Engine" is raised. If the server cannot be found after `rounds` probes,   *
 * a message is printed and Fail is raised carrying the caller's name        *
 * ("createInetServer" or "createUnixServer").                               *
 *---------------------------------------------------------------------------*)
local
  (* Shell lines assigning archOS and HOLDIR before the server is launched.
     NOTE(review): these are plain shell assignments, so they reach the
     server's environment only if the variables are already exported;
     the values are also unquoted, so paths containing spaces or shell
     metacharacters will break the command -- confirm intended usage. *)
  fun envPreamble archOS HOLDIR =
      String.concat ["archOS=", archOS, "\n", "HOLDIR=", HOLDIR, "\n"]

  (* Run comm; on success probe for the server with finder via Utils.attempt.
     attempt presumably retries up to `rounds` times and raises
     Fail "attempt" when every probe fails -- NOTE(review): confirm in Utils.
     es names the caller and is used as the Fail message on give-up. *)
  fun start_then_find comm (finder, rounds) es =
      if Process.system comm = Process.success
      then (attempt finder rounds
            handle Fail "attempt" =>
                   (stdOutput (String.concat
                                 ["Unable to find Core Proof Engine after ",
                                  Int.toString rounds, " probes.\n"]);
                    raise Fail es))
      else raise Fail "Unable to start Core Proof Engine"
in
fun createInetServer {archOS, HOLDIR} server ipno port rounds =
    let val launch = String.concat
                       [envPreamble archOS HOLDIR,
                        server, " ", ipno, " ", Int.toString port,
                        " > /dev/null &"]
    in
      start_then_find launch
                      (fn () => findInetServer ipno port, rounds)
                      "createInetServer"
    end

fun createUnixServer {archOS, HOLDIR} server rounds =
    let val file   = FileSys.tmpName ()
        val launch = String.concat
                       [envPreamble archOS HOLDIR,
                        server, " ", file, " > /dev/null &"]
    in
      (start_then_find launch
                       (fn () => findUnixServer file, rounds)
                       "createUnixServer",
       file)
    end
end;


(*---------------------------------------------------------------------------*
 * A simple internet server, not currently used. It accepts requests from    *
 * any client, but each client must reconnect for each transaction it wants *
 * to make.                                                                  *
 *---------------------------------------------------------------------------*)
(*
fun InetServer1 f ipnum port =
   let val serversock = Socket.inetStream ()
       val addr = Socket.inetAddr ipnum port
       val _    = Socket.bind(serversock, addr)
       val _    = Socket.listen(serversock, 5)
       val continue = ref true
   in
      while !continue do
         let val (clientsock,_) = Socket.accept serversock
         in
            f clientsock continue
                before
            Socket.close clientsock
         end;
      stdOutput ("inet_cpe : Closing serversock.\n");
      Socket.close serversock
   end;

(* example client *)

fun preAsk ipno port q =
    let val sock = findInetServer ipno port
    in ask sock q
    end;

val ask = preAsk "127.0.0.1" 5436;

*)


(*---------------------------------------------------------------------------*
 * A serial monogamist server: it repeatedly waits for one client and then   *
 * handles that client's requests until the client leaves. It takes a       *
 * freshly created socket and an address, made with Socket.fileAddr or      *
 * Socket.inetAddr, binds them, listens, and then enters its server loop,    *
 * which never returns normally.                                             *
 *                                                                           *
 * For each accepted client, f clientsock continue is called repeatedly      *
 * while !continue holds; f ends the session by setting continue to false.   *
 * Differences between the Internet and Unix domains have been handled by    *
 * the time this function is called (in the creation of serversock/addr).    *
 * If f raises, the client socket is closed and the exception propagates.    *
 *---------------------------------------------------------------------------*)
fun serialMonogamist f (serversock, addr) =
   let val _        = Socket.bind (serversock, addr)
       val _        = Socket.listen (serversock, 1)
       val continue = ref true

       (* Serve one client; never leak its socket if f raises. *)
       fun serve clientsock =
           (continue := true;
            while !continue do f clientsock continue)
           handle e => (Socket.close clientsock; raise e)
   in
      while true
      do let val (clientsock, _) = Socket.accept serversock
         in
            serve clientsock;
            Socket.close clientsock
         end
   end;


(*---------------------------------------------------------------------------*
 * Push a string into a socket. The string is first passed through           *
 * Utils.sizedstring, presumably adding the length prefix that readString    *
 * expects -- NOTE(review): confirm. The byte count returned by              *
 * Socket.sendVec is discarded, so a short write is not retried.             *
 *---------------------------------------------------------------------------*)
fun sendString sock s =
  (Socket.sendVec(sock, Byte.stringToBytes (sizedstring s)) ; ());


(*---------------------------------------------------------------------------*
 * Read one length-prefixed message from a socket. The message starts with  *
 * a decimal length, followed by one separator character (dropped), then    *
 * the payload. Returns (n, payload), where n is the decoded length.         *
 *                                                                           *
 * n is taken to be the payload length, not counting the header             *
 * (NOTE(review): confirm against Utils.sizedstring). Data is received in   *
 * chunks until at least n payload bytes have arrived; any extra bytes in   *
 * the last chunk are kept in the result. Raises Fail if the message is not *
 * prefixed with a size, or if the peer closes the connection mid-message.  *
 *---------------------------------------------------------------------------*)
local val chunksize = 1024
      val numeric   = Char.isDigit
in
fun readString sock =
 let (* One receive; may return "" when the peer has closed the socket. *)
     fun recvChunk () = Byte.bytesToString (Socket.recvVec (sock, chunksize))

     (* A receive that must make progress: an empty result means the peer
        closed the connection before the message was complete. *)
     fun recvMore () =
         let val s = recvChunk ()
         in if s = ""
            then raise Fail "readString: connection closed mid-message"
            else s
         end

     (* Split buf into (digits, rest). While only digits have arrived, the
        separator is still in flight, so keep reading. *)
     fun header buf =
         let val (digits, rest) = Substring.splitl numeric (Substring.all buf)
         in if Substring.isEmpty rest andalso not (Substring.isEmpty digits)
            then header (buf ^ recvMore ())
            else (digits, rest)
         end

     (* Accumulate chunks until togo more payload bytes have arrived. *)
     fun grab togo acc =
         if togo <= 0
         then String.concat (List.rev acc)
         else let val s = recvMore ()
              in grab (togo - String.size s) (s :: acc)
              end

     val (x, y) = header (recvChunk ())
     val ystr   = Substring.string (Substring.triml 1 y)
  in
    case Int.fromString (Substring.string x)
     of NONE => raise Fail "Badly formed message: not prefixed with size!"
      | SOME n =>
         (* Only the bytes after the header count towards the payload. *)
         (n, grab (n - String.size ystr) [ystr])
  end
end;


(*---------------------------------------------------------------------------*
 * For ML clients. Send a request on a socket and wait for the reply, which *
 * must start with "Some" or "None". Leading spaces after the constructor    *
 * are dropped, and the remainder becomes the argument of the corresponding *
 * Utils.result constructor. Raises Fail on any other prefix, or if the     *
 * request is not a single QUOTE fragment.                                   *
 *---------------------------------------------------------------------------*)

local val alpha = Char.contains"SomeNone"
      val space = Char.contains" "
in
fun sendRqt sock [QUOTE rqt] =
      let val _            = sendString sock rqt
          val (_, reply)   = readString sock
          val (prefx, rst) = Substring.splitl alpha (Substring.all reply)
          val replystring  = Substring.string (Substring.dropl space rst)
      in
        case Substring.string prefx
         of "None" => None replystring
          | "Some" => Some replystring
          |    _   => raise Fail (String.concat
                              ["Protocol violation: ",
                               "expecting a result constructor ",
                               "(\"Some\" or \"None\")."])
      end
   | sendRqt _ _ = raise Fail "sendRqt <sock> ` ... `"
end;


(*---------------------------------------------------------------------------*
 * For interactive use. Send q and print the reply on stdOutput: a None     *
 * reply is printed after a "Failed:" line, a Some reply is printed as is.   *
 *---------------------------------------------------------------------------*)
fun ask sock (q :unit frag list) =
    case (sendRqt sock q)
     of None s => stdOutput (String.concat["Failed: \n", s, "\n"])
      | Some s => stdOutput (String.concat[s, "\n"]);


(*---------------------------------------------------------------------------*
 * Send a final Stop (end_client) or Release (rel_client) command and close *
 * the socket. The socket is closed even if the exchange itself raises; the *
 * exception is then re-raised.                                              *
 *---------------------------------------------------------------------------*)
local
  fun sayAndClose sock q =
      (ask sock q handle e => (Socket.close sock; raise e);
       Socket.close sock)
in
fun end_client sock () = sayAndClose sock `Stop`

fun rel_client sock () = sayAndClose sock `Release`
end;

end;