{--
Copyright (c) 2006, Peng Li
              2006, Stephan A. Zdancewic
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:

* Redistributions of source code must retain the above copyright
  notice, this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright
  notice, this list of conditions and the following disclaimer in the
  documentation and/or other materials provided with the distribution.

* Neither the name of the copyright owners nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--}

module SockIO where
import Thread
import Data.List
import Control.Monad

-- Descriptor type accepted and returned by every operation in this module.
-- It is a plain alias for EpollFD, so a SockFD can be passed straight to
-- sys_epoll_wait, as the functions below do.
type SockFD = EpollFD

{-----------------------------------------------------------------------
Synchronous Socket I/O operations

These functions are designed to provide a higher-level interface to
the low-level system calls and nonblocking IO operations.  The
programming model is synchronous; errors are handled using exceptions.

The send/recv/close operations also work with pipes.
-----------------------------------------------------------------------}


{-----------------------------------------------------------------------
Open a listening socket on the given port.

Returns: the descriptor produced by nbio_listen.

Exception:
  The string "ERROR_SOCK_LISTEN:=" followed by the port number, raised
  when nbio_listen yields a negative descriptor.
-----------------------------------------------------------------------}
sock_listen :: Int -> CPSMonad SockFD
sock_listen pt = do
  listener <- sys_nbio (nbio_listen pt)
  when (listener < 0) $
    sys_throw $ "ERROR_SOCK_LISTEN:=" ++ show pt
  return listener

{-----------------------------------------------------------------------
Accept one incoming connection on a listening socket.

The accept is attempted with nbio_accept; while it reports that nothing
is pending (Nothing), the caller waits for EPOLL_READ on the listening
socket and then retries.

Returns: the descriptor of the accepted client connection.

Exceptions:
  ERROR_SOCK_ACCEPT   (nbio_accept produced a negative descriptor)
-----------------------------------------------------------------------}
sock_accept :: SockFD -> CPSMonad SockFD
sock_accept sock = sys_nbio (nbio_accept sock) >>= dispatch
 where
  -- Nothing pending yet: wait for readability, then start over.
  dispatch Nothing = sys_epoll_wait sock EPOLL_READ >> sock_accept sock
  -- An attempt completed: a negative descriptor signals failure.
  dispatch (Just client) = do
    when (client < 0) $ sys_throw "ERROR_SOCK_ACCEPT"
    return client


{-----------------------------------------------------------------------
Make an active (outgoing) connection to hostname:port.

Steps, in order:
  1. resolve the host name with dns_lookup (run through sys_blio);
  2. start the connection with nbio_connect;
  3. wait for EPOLL_WRITE on the new descriptor before returning it.

Returns: the descriptor of the new connection.

Exceptions:
  ERROR_SOCK_CONNECT:dns lookup failure   (dns_lookup returned 0)
  ERROR_SOCK_CONNECT                      (nbio_connect returned a
                                           negative descriptor)
-----------------------------------------------------------------------}
sock_connect :: String -> Int -> CPSMonad SockFD
sock_connect hostname port = resolve >>= open_socket >>= await_writable
 where
  -- Stage 1: name resolution; an address of 0 means the lookup failed.
  resolve = do
    addr <- sys_blio (dns_lookup hostname)
    when (addr == 0) $
      sys_throw "ERROR_SOCK_CONNECT:dns lookup failure"
    return addr

  -- Stage 2: create the socket and initiate the connection.
  open_socket addr = do
    conn <- sys_nbio (nbio_connect addr port)
    when (conn < 0) $
      sys_throw "ERROR_SOCK_CONNECT"
    return conn

  -- Stage 3: hand the descriptor back once it is reported writable.
  await_writable conn = sys_epoll_wait conn EPOLL_WRITE >> return conn

{-----------------------------------------------------------------------
Close a socket.

Returns: the integer result of nbio_close, passed through unchanged
         (it is not inspected here and no exception is raised).
-----------------------------------------------------------------------}
sock_close :: SockFD -> CPSMonad Int
sock_close = sys_nbio . nbio_close


{-----------------------------------------------------------------------
Receive data from a socket. Return as soon as
  (1) some data is available, or (2) EOF is reached.

A fresh chunk is allocated with new_chunk for every call (size 1024,
presumably bytes) and a single successful nbio_read is performed into
it.  While nbio_read reports that nothing is available (Nothing), the
caller waits for EPOLL_READ on the socket and retries with the same
chunk.

Returns: A new Chunk object pointing to the received data, with its
         length set to the amount actually read.
         On EOF, the new Chunk object has a zero length.

Exceptions:
  ERROR_SOCK_READ   (nbio_read returned -1)
-----------------------------------------------------------------------}
sock_recv_any :: SockFD -> CPSMonad Chunk
sock_recv_any sock = sys_nbio (new_chunk recv_buf_size) >>= recv_loop
 where
  -- Capacity requested for the receive buffer of each call.
  recv_buf_size = 1024

  -- Only the buffer pointer and offset are reused in the result; the
  -- capacity is replaced by the count that nbio_read reports.
  recv_loop buf@(Chunk fptr off _) = do
    result <- sys_nbio $ nbio_read sock buf
    case result of
      Nothing -> do
        sys_epoll_wait sock EPOLL_READ
        recv_loop buf
      Just 0 -> return $ Chunk fptr off 0
      Just (-1) -> do
        sys_throw "ERROR_SOCK_READ"
        -- Kept so this branch has the same result type as the others.
        return $ Chunk fptr off 0
      Just numread -> return $ Chunk fptr off numread
   

{-----------------------------------------------------------------------
Send data over a socket.  Keep sending until all data in the buffer
are sent.

Each round hands the remaining part of the chunk to nbio_write:
  * Nothing       - nothing could be written; wait for EPOLL_WRITE and
                    retry with the same chunk;
  * Just (-1)     - failure;
  * Just n        - n units were written; if that is not the whole
                    remainder, continue with the offset advanced by n
                    and the length reduced by n.

Exceptions:
  ERROR_SOCK_WRITE
-----------------------------------------------------------------------}
sock_send_all :: SockFD -> Chunk -> CPSMonad ()
sock_send_all sock chk@(Chunk fptr off len) =
  sys_nbio (nbio_write sock chk) >>= handle_result
 where
  handle_result Nothing =
    sys_epoll_wait sock EPOLL_WRITE >> sock_send_all sock chk
  handle_result (Just (-1)) = sys_throw "ERROR_SOCK_WRITE"
  handle_result (Just numwritten) =
    unless (numwritten == len) $
      sock_send_all sock (Chunk fptr (off + numwritten) (len - numwritten))
  

{-----------------------------------------------------------------------
Receive data from a socket. Keep receiving until the buffer is full.

A chunk of length zero is already full, so it returns immediately
without touching the socket.  (A read into an empty buffer could only
report 0, which the loop below would otherwise mistake for EOF.)

Each round hands the unfilled part of the chunk to nbio_read:
  * Nothing       - nothing available; wait for EPOLL_READ and retry;
  * Just 0        - the peer closed before the buffer was filled;
  * Just (-1)     - failure;
  * Just n        - n units were read; if the buffer is not yet full,
                    continue with the offset advanced by n and the
                    length reduced by n.

Exceptions:
  EOF
  ERROR_SOCK_READ
-----------------------------------------------------------------------}
sock_recv_all :: SockFD -> Chunk -> CPSMonad ()
sock_recv_all sock chk@(Chunk fptr off len)
  | len == 0 = return ()
  | otherwise = do
      result <- sys_nbio $ nbio_read sock chk
      case result of
        Nothing -> do
          sys_epoll_wait sock EPOLL_READ
          sock_recv_all sock chk
        Just 0 -> sys_throw "EOF"
        Just (-1) -> sys_throw "ERROR_SOCK_READ"
        Just numread ->
          unless (numread == len) $
            sock_recv_all sock (Chunk fptr (off + numread) (len - numread))


{-- Higher-level wrappers ---------------------------------------------}

{-----------------------------------------------------------------------
Send a Haskell string over a socket.

The string is converted to a Chunk with string_to_chunk and the whole
chunk is then written with sock_send_all.

Exceptions:
  ERROR_SOCK_WRITE
-----------------------------------------------------------------------}
sock_send_string :: SockFD -> String -> CPSMonad ()
sock_send_string sock s =
  sys_nbio (string_to_chunk s) >>= sock_send_all sock

{-----------------------------------------------------------------------
Receive some data from a socket and return the received data as a
Haskell string.

This is sock_recv_any followed by chunk_to_string, so it returns after
a single successful read rather than waiting for any particular amount
of data.

On EOF, an empty string is returned.

Exceptions:
  ERROR_SOCK_READ
-----------------------------------------------------------------------}
sock_recv_string :: SockFD -> CPSMonad String
sock_recv_string sock =
  sock_recv_any sock >>= sys_nbio . chunk_to_string

{-----------------------------------------------------------------------
Write a String to a socket, followed by a CR/LF line terminator.

The terminator is always appended; the string itself is sent as given.

Exception:
  ERROR_SOCK_WRITE
-----------------------------------------------------------------------}
sock_write_line :: SockFD -> String -> CPSMonad ()
sock_write_line sock s = sock_send_string sock (s ++ "\r\n")

{-----------------------------------------------------------------------
Read a line from a socket.  CR/LF are trimmed.

There is a buffer in the input arguments as well as in the output; the
buffer represents the received yet unparsed portion of the input.

Returns: (line, rest) where line is everything before the first LF
         (with one trailing CR removed, if present) and rest is
         everything after that LF, to be passed to the next call.

If the buffer does not yet contain an LF, more data is fetched with
sock_recv_string until one arrives.  Only newly received text is
searched for the LF; text that was already scanned is kept aside and
joined once the line is complete.

Exception:
  ERROR_SOCK_READ
  EOF   (the peer closed before an LF was seen; the partial line
         collected so far is not returned)
-----------------------------------------------------------------------}
sock_read_line :: SockFD -> String -> CPSMonad (String, String)
sock_read_line sock buffer = scan [] buffer
 where
  -- seen:    LF-free pieces already searched, most recent first.
  -- pending: text that has not been searched for an LF yet.
  scan seen pending =
    case break (== '\n') pending of
      (prefix, _ : rest) ->
        return (trim (concat (reverse (prefix : seen))), rest)
      (prefix, []) -> do
        s <- sock_recv_string sock
        when (null s) $ sys_throw "EOF"
        scan (prefix : seen) s

  -- Drop a single trailing CR.  The empty case is handled first, so
  -- last and init are only ever applied to a non-empty string.
  trim [] = []
  trim s
    | last s == '\r' = init s
    | otherwise = s
