{--
Copyright (c) 2006, Peng Li
              2006, Stephan A. Zdancewic
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:

* Redistributions of source code must retain the above copyright
  notice, this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright
  notice, this list of conditions and the following disclaimer in the
  documentation and/or other materials provided with the distribution.

* Neither the name of the copyright owners nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--}

module PacketIO 
  ( packet_setup     -- :: IO ()
  , packet_send_ip   -- :: IPMessage -> IO ()
  , packet_recv_ip   -- :: IO IPMessage
  )
where

import System
import System.IO
import System.Posix.IO
import System.Posix.Types

import Foreign
import Foreign.C
import Control.Exception

import TCP.Type.Base
import TCP.Type.Datagram
import TCP.Aux.Misc
import TCP.Aux.Param

-- | Initialise both directions of the raw packet interface: first the C
-- sender, then the C receiver.  The 'CInt' status each initialiser returns
-- is discarded, so a failed initialisation is not reported here.
packet_setup :: IO ()
packet_setup =
  init_packet_sender >> init_packet_receiver >> return ()

-- | Transmit an IP message.  Only 'TCPMessage' is supported; any other
-- constructor raises an 'error'.
packet_send_ip :: IPMessage -> IO ()
packet_send_ip msg =
  case msg of
    TCPMessage seg -> send_tcp_packet seg
    _              -> error "packet_send_ip can only do TCP packets for the moment"

-- | Receive the next IP message from the raw packet interface.
--
-- Delegates to 'receive_tcp_packet', which yields 'Nothing' when no packet
-- was obtained or the packet is not TCP.
packet_recv_ip :: IO (Maybe IPMessage)
packet_recv_ip =
  receive_tcp_packet


-- Raw-packet sender entry points, implemented in the C module mod_sendpkt.
-- The meaning of each 'CInt' result is defined on the C side.

-- | Initialise the raw-packet sender.  'packet_setup' calls this and
-- discards the result.
foreign import ccall "mod_sendpkt.h initrawsocket" 
  init_packet_sender :: IO CInt
-- | Return the C-side handle used for sending.
-- NOTE(review): presumably a socket descriptor -- confirm in mod_sendpkt.
foreign import ccall "mod_sendpkt.h getsendhandle" 
  getOutgoingRawHandle :: IO CInt
-- | Poll whether the send handle is currently writable.
foreign import ccall "mod_sendpkt.h test_send_handle_writable" 
  test_send_handle_writable :: IO CInt  -- returns 1 if writable, 0 not writable


-- Raw-packet receiver entry points, implemented in the C module mod_recvpkt.

-- | Receive one packet into the buffer at the given pointer.
-- 'receive_tcp_packet' uses the result as the byte offset of the IP header
-- inside that buffer and treats 0 as "no packet".
-- Declared @safe@, so under the threaded RTS other Haskell threads keep
-- running while the call is in progress.
-- NOTE(review): whether the call blocks, and how large the buffer must be,
-- is defined in mod_recvpkt -- confirm there.
foreign import ccall safe "mod_recvpkt.h recvpacket"  
  receiveRawTCPPacket :: Ptr Word8->IO CInt
-- | Initialise the receiver.  'packet_setup' calls this and discards the
-- result.
foreign import ccall "mod_recvpkt.h initipqueue" 
  init_packet_receiver :: IO CInt
-- | Return the C-side handle used for receiving.
-- NOTE(review): presumably a descriptor -- confirm in mod_recvpkt.
foreign import ccall "mod_recvpkt.h getrecvhandle" 
  getIncomingRawHandle :: IO CInt
-- | Poll whether the receive handle currently has data to read.
foreign import ccall "mod_recvpkt.h test_recv_handle_readable" 
  test_recv_handle_readable :: IO CInt  -- presumably returns 1 if readable, 0 if not (confirm in mod_recvpkt)


------------ Utilities  ------------------------
-- | Swap the two bytes of a 16-bit word.  This equals host-to-network
-- conversion only on a little-endian host.
htons16 :: Word16 -> Word16
htons16 w = (w `shiftL` 8) .|. (w `shiftR` 8)

-- | Inverse of 'htons16' (a byte swap is its own inverse).
ntohs16 :: Word16 -> Word16
ntohs16 = htons16

-- | Reverse the four bytes of a 32-bit word (despite the name, this is the
-- 32-bit "htonl" operation).  It equals host-to-network conversion only on
-- a little-endian host.
htons32 :: Word32 -> Word32
htons32 w =
      (byteAt 0 `shiftL` 24)
  .|. (byteAt 1 `shiftL` 16)
  .|. (byteAt 2 `shiftL` 8)
  .|.  byteAt 3
  where
    -- n-th byte of w, counting from the least significant end
    byteAt n = (w `shiftR` (8 * n)) .&. 0xff

-- | Inverse of 'htons32' (a byte reversal is its own inverse).
ntohs32 :: Word32 -> Word32
ntohs32 = htons32

{-# INLINE htons16 #-}
{-# INLINE ntohs16 #-}
{-# INLINE htons32 #-}
{-# INLINE ntohs32 #-}

-- | Store @thing@ into each of the first @l@ elements of the array at @p@.
--
-- A zero or negative @l@ writes nothing.  An index counter that only stops
-- on equality would never terminate for negative lengths and would run off
-- the end of the array.
fillArray :: (Storable a) => Ptr a -> Int -> a -> IO ()
fillArray p l thing = mapM_ (\i -> pokeElemOff p i thing) [0 .. l - 1]

--------- Marshalling -----------------------------
{-# INLINE marshallWord8       #-} 
{-# INLINE marshallWord16      #-}
{-# INLINE marshallWord32      #-}
{-# INLINE parseWord8          #-}
{-# INLINE parseWord16         #-}
{-# INLINE parseWord32         #-}
{-# INLINE marshall_ip_header  #-}
{-# INLINE marshall_tcp_header #-}
{-# INLINE marshall_tcp_bits   #-}
{-# INLINE parse_tcp_bits      #-}

-- | Write one byte at byte offset @offset@ from @p@.
marshallWord8 :: Ptr Word8 -> Int -> Word8 -> IO ()
marshallWord8 = pokeElemOff

-- | Write a 16-bit value in network (big-endian) byte order at byte offset
-- @offset@.  The value is stored one byte at a time, so the result depends
-- neither on host endianness nor on the alignment of @p + offset@.
marshallWord16 :: Ptr Word8 -> Int -> Word16 -> IO ()
marshallWord16 p offset w = do
  pokeElemOff p offset       (fromIntegral (w `shiftR` 8))
  pokeElemOff p (offset + 1) (fromIntegral w)

-- | Write a 32-bit value in network (big-endian) byte order at byte offset
-- @offset@, one byte at a time (endian- and alignment-independent).
marshallWord32 :: Ptr Word8 -> Int -> Word32 -> IO ()
marshallWord32 p offset w = do
  putByte 0 24
  putByte 1 16
  putByte 2 8
  putByte 3 0
 where
  -- byte i of the output holds bits [s, s+8) of w
  putByte i s = pokeElemOff p (offset + i) (fromIntegral (w `shiftR` s))

-- | Read one byte at byte offset @offset@ from @p@.
parseWord8 :: Ptr Word8 -> Int -> IO Word8
parseWord8 = peekElemOff

-- | Read a 16-bit network-order (big-endian) value at byte offset @offset@.
-- Bytes are read individually, so the result depends neither on host
-- endianness nor on the alignment of @p + offset@.
parseWord16 :: Ptr Word8 -> Int -> IO Word16
parseWord16 p offset = do
  hi <- peekElemOff p offset
  lo <- peekElemOff p (offset + 1)
  return ((fromIntegral hi `shiftL` 8) .|. fromIntegral lo)

-- | Read a 32-bit network-order (big-endian) value at byte offset @offset@,
-- one byte at a time (endian- and alignment-independent).
parseWord32 :: Ptr Word8 -> Int -> IO Word32
parseWord32 p offset = do
  b0 <- byteAt 0
  b1 <- byteAt 1
  b2 <- byteAt 2
  b3 <- byteAt 3
  return ((b0 `shiftL` 24) .|. (b1 `shiftL` 16) .|. (b2 `shiftL` 8) .|. b3)
 where
  byteAt i = fmap fromIntegral (peekElemOff p (offset + i))

-- | Write a 20-byte IPv4 header (no options) for @seg@ at @p@.
--
-- The header carries version 4 and IHL 5, the DF flag, TTL 128 and
-- protocol 6 (TCP).  The source and destination addresses come from the
-- segment's endpoints.  Total length, identification and header checksum
-- are written as 0.
-- NOTE(review): presumably the C sender or the kernel fills those in --
-- confirm in mod_sendpkt.
marshall_ip_header :: TCPSegment -> Ptr Word8 -> IO ()
marshall_ip_header seg p = do
  marshallWord8  p 0  ((4 `shiftL` 4) + 5)   -- version 4, header length 5 words
  marshallWord8  p 1  0                      -- type of service
  marshallWord16 p 2  0                      -- total length
  marshallWord16 p 4  0                      -- identification
  marshallWord16 p 6  (2 `shiftL` 13)        -- flags: DF; fragment offset 0
  marshallWord8  p 8  128                    -- time to live
  marshallWord8  p 9  6                      -- protocol: TCP
  marshallWord16 p 10 0                      -- header checksum
  marshallWord32 p 12 srcAddr
  marshallWord32 p 16 dstAddr
 where
  IPAddr srcAddr = get_ip (tcp_src seg)
  IPAddr dstAddr = get_ip (tcp_dst seg)

-- | Write the fixed 20-byte TCP header (no options) for @seg@ at @p@.
--
-- Byte offsets written: 0 source port, 2 destination port, 4 sequence
-- number, 8 acknowledgement number, 12 data offset and control bits,
-- 14 window, 16 checksum, 18 urgent pointer.  All 20 bytes are written,
-- so the result does not depend on the previous contents of the buffer.
--
-- The checksum is written as 0.
-- NOTE(review): presumably computed by the C sender -- confirm in
-- mod_sendpkt.
marshall_tcp_header :: TCPSegment-> Ptr Word8 -> IO ()
marshall_tcp_header seg p = do
  marshallWord16 p 0  $ get_port $ tcp_src seg
  marshallWord16 p 2  $ get_port $ tcp_dst seg
  marshallWord32 p 4  $ seq_val $ tcp_seq seg
  marshallWord32 p 8  $ seq_val $ tcp_ack seg
  marshallWord16 p 12 field1
  marshallWord16 p 14 $ to_Word16 $ tcp_win seg
  marshallWord16 p 16 0
  -- urgent pointer: must be written explicitly, since URG may be set above
  -- and the buffer is not known to be zero-filled
  marshallWord16 p 18 $ fromIntegral $ tcp_urp seg
 where field1 = (shiftL (5) 12) -- dataoffset=5 (32-bit words, i.e. 20 bytes)
                 + (shiftL (0) 6)  -- ecn=0
                 + (marshall_tcp_bits (tcp_URG seg) (tcp_ACK seg) (tcp_PSH seg) (tcp_RST seg) (tcp_SYN seg) (tcp_FIN seg))

-- | Pack the six TCP control flags into the low six bits of a number:
-- URG = 32, ACK = 16, PSH = 8, RST = 4, SYN = 2, FIN = 1.
marshall_tcp_bits :: Num a => Bool -> Bool -> Bool -> Bool -> Bool -> Bool -> a
marshall_tcp_bits urg ack psh rst syn fin =
    sum [ weight
        | (True, weight) <- zip [urg, ack, psh, rst, syn, fin] [32, 16, 8, 4, 2, 1] ]

-- | Unpack the six TCP control flags from the low bits of @x@, returned in
-- the order (URG, ACK, PSH, RST, SYN, FIN), i.e. bit weights 32 down to 1.
parse_tcp_bits :: (Bits a, Num a, Ord a) => a -> (Bool, Bool, Bool, Bool, Bool, Bool)
parse_tcp_bits x = (flag 32, flag 16, flag 8, flag 4, flag 2, flag 1)
  where
    flag mask = (x .&. mask) > 0

-- | Serialise @seg@ as a raw IPv4/TCP packet and pass it to the C
-- @sendpacket@ routine.
--
-- The headers and payload are assembled as follows:
--
-- * A fresh 40-byte buffer receives a 20-byte IP header and a 20-byte TCP
--   header (no options), and is prepended to the segment's payload chain.
-- * The chain is described to C as an array of (data pointer, byte length)
--   entries laid out like POSIX @struct iovec@.
-- * Every buffer stays pinned by a nested 'withForeignPtr' until the C call
--   returns.
--
-- The 'CInt' returned by @sendpacket@ is discarded, so a failed send is not
-- reported to the caller.
send_tcp_packet :: TCPSegment -> IO ()
send_tcp_packet seg = do
  header@(Buffer fptr _ _ _) <- new_buffer header_len
  withForeignPtr (castForeignPtr fptr) $ \(orig :: Ptr Word8) -> do
    marshall_ip_header seg orig
    marshall_tcp_header seg (orig `plusPtr` iphlen)
  let packet@(BufferChain lst _) = bufferchain_add header d
      iov_count = assert (bufferchain_ok packet) $ length lst
  -- allocaBytes releases the vector even if an exception escapes
  _ <- allocaBytes (iov_count * iov_entry_size) $ \iov_ptr ->
         fill_vec lst iov_ptr (iov_ptr, iov_count)
  return ()
 where
  iphlen = 20
  tcphlen = 20
  header_len = iphlen + tcphlen
  total_len = header_len + bufc_length d
  d = tcp_data seg

  -- One vector entry is a data pointer followed by a size_t length, each
  -- pointer-sized (struct iovec).  On 32-bit hosts this is 8 bytes with a
  -- 32-bit length at offset 4; on 64-bit hosts it is 16 bytes.
  -- NOTE(review): assumes mod_sendpkt reads the array as struct iovec --
  -- confirm against the C source.
  ptr_size = sizeOf (undefined :: Ptr ())
  iov_entry_size = 2 * ptr_size

  -- Write one entry per buffer.  withForeignPtr is nested so that every
  -- buffer stays alive until sendpacket runs at the end of the list.
  fill_vec (Buffer buf_fptr _ off blen : rest) cur_vec_ptr iov =
    withForeignPtr buf_fptr $ \buf_ptr -> do
      poke (castPtr cur_vec_ptr) (buf_ptr `plusPtr` off :: Ptr ())
      poke (castPtr (cur_vec_ptr `plusPtr` ptr_size)) (fromIntegral blen :: CSize)
      fill_vec rest (cur_vec_ptr `plusPtr` iov_entry_size) iov
  fill_vec [] _ (iov_ptr, iov_count) =
    sendRawTCPPacket iov_ptr
                     (fromIntegral iov_count)
                     (fromIntegral iphlen)
                     (fromIntegral tcphlen)
                     (fromIntegral total_len)


-- | C routine that transmits one packet described by a vector of
-- (pointer, length) entries.
--
-- As called from 'send_tcp_packet', the arguments are:
--
-- * the entry vector,
-- * the number of entries,
-- * the IP header length,
-- * the TCP header length,
-- * the total packet length in bytes.
--
-- The meaning of the 'CInt' result is defined in mod_sendpkt; that caller
-- discards it.  Declared @safe@, so under the threaded RTS other Haskell
-- threads keep running while the call is in progress.
foreign import ccall safe "mod_sendpkt.h sendpacket" 
  sendRawTCPPacket :: Ptr CChar -> CInt -> CInt -> CInt -> CInt -> IO CInt

-- | Size in bytes of the buffer that 'receive_tcp_packet' allocates for each
-- incoming packet.  No type signature is given, so the type is fixed by its
-- use as the argument to 'new_buffer'.
packetBufferSize = 2000

-- | Receive one packet via the C @recvpacket@ routine and decode it as a
-- TCP segment carried in IPv4.
--
-- A fresh 'packetBufferSize'-byte buffer is filled by @recvpacket@.  Its
-- result is used as the byte offset of the IP header within that buffer,
-- and a non-positive result means "no packet".
--
-- Returns 'Nothing' when:
--
-- * no packet was obtained;
-- * the IP protocol field is not 6 (TCP), after printing a notice;
-- * the header-length or total-length fields are inconsistent, or would
--   reach past the end of the buffer.  These fields come off the wire and
--   are therefore untrusted.
--
-- The returned segment's payload is a view into the receive buffer; it is
-- not copied.  Only the fixed 20-byte TCP header is decoded.  TCP options
-- are skipped, so window scale, MSS and timestamps are always 'Nothing',
-- and the urgent pointer is reported as 0.
receive_tcp_packet :: IO (Maybe IPMessage)
receive_tcp_packet = do
  (Buffer fptr size _ _) <- new_buffer packetBufferSize
  withForeignPtr (castForeignPtr fptr) $ \(p :: Ptr Word8) -> do
    res <- receiveRawTCPPacket p
    let ipOff   = fromIntegral res  :: Int   -- offset of the IP header
        bufSize = fromIntegral size :: Int
        -- payload view: len bytes starting at absolute buffer offset off
        payload off len =
          bufferchain_singleton (Buffer fptr size (fromIntegral off) (fromIntegral len))
    if ipOff <= 0 || ipOff + minHeaderLen > bufSize
      then return Nothing
      else parse_ipv4 (p `plusPtr` ipOff) ipOff bufSize payload
 where
  -- size of an IPv4 or TCP header without options
  minHeaderLen = 20 :: Int

  -- Decode the IP header and validate its length fields against the buffer.
  parse_ipv4 ipstart ipOff bufSize payload = do
    tmp0 <- parseWord8  ipstart 0
    tmp2 <- parseWord16 ipstart 2
    tmp9 <- parseWord8  ipstart 9
    let iphlen   = fromIntegral (tmp0 .&. 15) * 4 :: Int  -- IHL is in 32-bit words
        totallen = fromIntegral tmp2 :: Int
        sane     = iphlen >= minHeaderLen
                   && totallen >= iphlen + minHeaderLen
                   && ipOff + totallen <= bufSize
    if tmp9 /= 6
      then do { putStrLn "Non-TCP packet received..."; return Nothing }
      else if not sane
        then return Nothing
        else parse_tcp ipstart ipOff iphlen totallen payload

  -- Decode the fixed TCP header that follows the IP header.  The sanity
  -- check in parse_ipv4 guarantees these 20 bytes lie inside the buffer.
  parse_tcp ipstart ipOff iphlen totallen payload = do
    tmp12 <- parseWord32 ipstart 12
    tmp16 <- parseWord32 ipstart 16
    let source_addr = IPAddr tmp12
        dest_addr   = IPAddr tmp16
        tcpstart    = ipstart `plusPtr` iphlen
    t0  <- parseWord16 tcpstart 0
    t2  <- parseWord16 tcpstart 2
    t4  <- parseWord32 tcpstart 4
    t8  <- parseWord32 tcpstart 8
    t12 <- parseWord16 tcpstart 12
    t14 <- parseWord16 tcpstart 14
    let tcphlen = fromIntegral (shiftR t12 12) * 4 :: Int  -- data offset, in 32-bit words
        datalen = totallen - iphlen - tcphlen
        (urg,ack,psh,rst,syn,fin) = parse_tcp_bits (t12 .&. 63)
    if tcphlen < minHeaderLen || datalen < 0
      then return Nothing
      else return $ Just $ TCPMessage $ TCPSegment
        { tcp_src = TCPAddr ( source_addr, t0 )
        , tcp_dst = TCPAddr ( dest_addr,  t2 )
        , tcp_seq = SeqLocal t4
        , tcp_ack = SeqForeign t8
        , tcp_URG = urg
        , tcp_ACK = ack
        , tcp_PSH = psh
        , tcp_RST = rst
        , tcp_SYN = syn
        , tcp_FIN = fin
        , tcp_win = to_Int $ t14
        , tcp_urp = 0
        , tcp_data = payload (ipOff + iphlen + tcphlen) datalen
        -- option: window scaling
        , tcp_ws      = Nothing
        -- option: max segment size
        , tcp_mss     = Nothing
        -- option: RFC1323
        , tcp_ts      = Nothing
        }
