/*
 * mlfeatureextractor.{cc,hh} -- extract a set of flow features to be used by a
 *                               machine learning classifier
 *
 * Wei Li
 * Marco Canini
 *
 * Copyright (c) 2008-09 by University of Genova - DIST - TNT laboratory
 *
 * Redistribution and use in source and binary forms, with or
 * without modification, are permitted provided that the following
 * conditions are met:
 *
 * * Redistributions of source code must retain the above copyright
 *   notice, this list of conditions and the following disclaimer.
 * * Redistributions in binary form must reproduce the above copyright
 *   notice, this list of conditions and the following disclaimer in
 *   the documentation and/or other materials provided with the distribution.
 * * Neither the name of University of Genova nor the names of its contributors
 *   may be used to endorse or promote products derived from this software
 *   without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
 * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
 * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * $Id: mlfeatureextractor.cc 2578 2009-10-05 10:32:45Z marco $
 */

#include <click/config.h>
#include <click/error.hh>
#include <click/straccum.hh>
#include <click/confparse.hh>
#include <click/packet_anno.hh>
#include <clicknet/ip.h>
#include <clicknet/tcp.h>
#include <clicknet/udp.h>

#include <iostream>
#include <stdlib.h>
#include <vector>

#include "mlfeatureextractor.hh"
#include "nipquad.hh"

CLICK_DECLS

// File-local helpers, defined at the bottom of this file.
namespace
{
    // qsort() comparator that orders uint16_t values ascending.
    int bytes_compare(const void *one, const void *two);
    // Median of the first n values of arr; arr may be reordered in place
    // (it is sorted with qsort). Note the count is a uint8_t.
    double bytes_median(uint16_t arr[], uint8_t n);
}


/*
 * Creates the element with no flow cache attached yet (it is looked up in
 * configure()) and allocates the stderr diagnostic handler that the
 * destructor releases.
 */
MLFeatureExtractor::MLFeatureExtractor()
    : flow_cache(0)
{
    const char* const prefix = "MLFeatureExtractor::";
    log = new FileErrorHandler(stderr, prefix);
}

/*
 * Releases the diagnostic handler allocated in the constructor.
 */
MLFeatureExtractor::~MLFeatureExtractor()
{
    delete log;
}

int MLFeatureExtractor::configure(Vector<String>& /*conf*/, ErrorHandler* errh)
{
    log->debug("configure");
    flow_cache = FlowCache::upstream_instance(this);

    series_len = 5;
    if (!flow_cache)
        return errh->error("Initialization failure!");

    if (series_len == 0)
        return errh->error("SERIES_LENGTH must be greater then zero");

    flow_cache->register_flow_state_holder(this);

    return 0;
}

/*
 * Per-packet entry point. TCP packets update the per-flow feature state,
 * which is created on the first packet of a flow.
 *
 * Every other packet is forwarded unchanged. That includes a packet whose
 * IP header annotation is unset: it is passed through instead of being
 * dereferenced.
 *
 * Returns the packet, or 0 after killing it if the flow state cannot be
 * allocated.
 */
Packet* MLFeatureExtractor::simple_action(Packet* p)
{
    const click_ip *iph = p->ip_header();
    if (!iph || iph->ip_p != IP_PROTO_TCP)
        return p;

    FlowState* fs = flow_cache->lookup_state<FlowState>(this, p);
    if (!fs)
    {
        fs = new FlowState(p, series_len);
        // Validate the allocation before publishing it to the flow cache.
        if (!fs)
        {
            click_chatter("out of memory!");
            p->kill();
            return 0;
        }
        flow_cache->state(this, p) = fs;
    }

    fs->handle_packet(p, series_len);
    return p;
}

/*
 * Declares the exported feature columns, one per feature.
 *
 * The order here must match the order in which write_flow_state() emits
 * values.
 *
 * NOTE(review): the column types (uint8_t / uint16_t / double) should agree
 * with the FlowState field types declared in the header. Confirm when
 * changing either side.
 */
void MLFeatureExtractor::write_header(DataExport& exporter) const
{
    exporter += column<uint8_t>("DataPktsCln", true);
    exporter += column<uint8_t>("PushPktsCln", true);
    exporter += column<uint8_t>("PushPktsSrv", true);
    exporter += column<uint16_t>("MinSegCln", true);
    exporter += column<uint16_t>("AvgSegSrv", true);
    exporter += column<uint16_t>("IniWinBytesCln", true);
    exporter += column<uint16_t>("IniWinBytesSrv", true);
    exporter += column<uint8_t>("RttSamplesCln", true);
    exporter += column<double>("MedIpDataBytesCln", true);
    exporter += column<double>("VarDataBytesSrv", true);
}


/*
 * Emits one row of feature values for a flow.
 *
 * The values appear in the same order as the columns declared by
 * write_header().
 */
void MLFeatureExtractor::write_flow_state(DataExport& exporter, const BaseFlowState* fs_) const
{
    const FlowState& st = *static_cast<const FlowState*>(fs_);

    // DataPktsCln, PushPktsCln, PushPktsSrv
    exporter << st.data_pkts_clnt << st.push_pkts_clnt << st.push_pkts_serv;
    // MinSegCln, AvgSegSrv
    exporter << st.min_seg_clnt << st.avg_seg_serv;
    // IniWinBytesCln, IniWinBytesSrv
    exporter << st.ini_win_bytes_clnt << st.ini_win_bytes_serv;
    // RttSamplesCln, MedIpDataBytesCln
    exporter << st.rtt_samps_clnt << st.ip_data_byte_med_clnt;
    // VarDataBytesSrv
    exporter << st.eth_data_byte_var_serv;
}

/*
 * Called when a flow ends.
 *
 * If the series filled up, handle_packet() has already finalized the
 * features. Otherwise the partial series is finalized here.
 */
void MLFeatureExtractor::flow_over(const FlowCache::Flow&, BaseFlowState* fs_)
{
    FlowState* const state = static_cast<FlowState*>(fs_);
    if (state->count != series_len)
        state->compute();
}

/*
 * Creates an empty per-flow state.
 *
 * Allocates a buffer for up to series_len packet records; it is released by
 * compute() or by the destructor. The packet argument is unused.
 */
MLFeatureExtractor::FlowState::FlowState(Packet*, size_t series_len)
    : count(0)
{
    series = new SeriesType[series_len];

    // Intermediate counters used while the series is being collected.
    pkts_clnt = 0;
    pkts_serv = 0;
    data_pkts_serv = 0;
    avg_eth_data_byte_serv = 0;   // holds a sum until compute() divides it
    last_retr[0] = last_retr[1] = -1;
    ini_win_clnt_closed = ini_win_serv_closed = 0;

    // Exported features.
    data_pkts_clnt = 0;
    push_pkts_clnt = 0;
    push_pkts_serv = 0;
    min_seg_clnt = 0;             // 0 means "no client data segment seen yet"
    avg_seg_serv = 0;             // accumulates a sum; averaged in compute()
    ini_win_bytes_clnt = 0;       // retransmissions are not excluded yet
    ini_win_bytes_serv = 0;       // retransmissions are not excluded yet
    rtt_samps_clnt = 0;
    ip_data_byte_med_clnt = 0;
    eth_data_byte_var_serv = 0;
}

/*
 * Frees the packet-record buffer.
 *
 * compute() may already have released it and set it to null; delete[] on a
 * null pointer is a no-op.
 */
MLFeatureExtractor::FlowState::~FlowState()
{
    SeriesType* buffer = series;
    series = 0;
    delete[] buffer;
}

/*
 * Records one TCP packet of the flow and updates the running feature
 * counters.
 *
 * The direction comes from the paint annotation: 0 = client -> server,
 * 1 = server -> client (presumably only these two values occur).
 *
 * After series_len packets the features are finalized with compute().
 * Packets arriving after that, or after compute() has released the buffer
 * (e.g. from flow_over()), are ignored.
 */
void MLFeatureExtractor::FlowState::handle_packet(Packet* p, size_t series_len)
{
    if (count == series_len || series == 0)
        return; // features already finalized

    /* get basic stats */
    const click_ip *iph = p->ip_header();
    const click_tcp *tcph = p->tcp_header();

    int dir = PAINT_ANNO(p);
    size_t len = ntohs(iph->ip_len);
    // A malformed packet can claim more header bytes than its total length;
    // clamp to zero instead of letting the unsigned subtraction wrap around.
    size_t hdr_len = p->ip_header_length() + tcph->th_off * 4;
    size_t data_len = len > hdr_len ? len - hdr_len : 0;

    series[count].dir = dir;
    series[count].len = len;
    series[count].seq = ntohl(tcph->th_seq);
    series[count].ack = ntohl(tcph->th_ack);
    series[count].data_len = data_len;

    /* calculate features */
    if (dir == 1) {
        // server -> client
        avg_eth_data_byte_serv += len;
        pkts_serv ++;
    } else {
        // client -> server
        pkts_clnt ++;
    }

    if (data_len > 0) {
        if (dir == 0) {
            // client -> server
            data_pkts_clnt ++;
            if (min_seg_clnt == 0 || min_seg_clnt > data_len)
                min_seg_clnt = data_len;
            // TODO: retransmissions should be excluded from the initial window.
            if (ini_win_clnt_closed == 0)
                ini_win_bytes_clnt += data_len;
            // Client data ends the server's initial window.
            if (ini_win_bytes_serv > 0 && ini_win_serv_closed == 0)
                ini_win_serv_closed = 1;
        } else {
            // server -> client
            data_pkts_serv ++;
            avg_seg_serv += data_len;
            // TODO: retransmissions should be excluded from the initial window.
            if (ini_win_serv_closed == 0)
                ini_win_bytes_serv += data_len;
            // Server data ends the client's initial window.
            if (ini_win_bytes_clnt > 0 && ini_win_clnt_closed == 0)
                ini_win_clnt_closed = 1;
        }
    }

    if (tcph->th_flags & TH_PUSH)
    {
        if (dir == 0)
            push_pkts_clnt ++;
        else
            push_pkts_serv ++;
    }

    // RTT sample: a server packet that directly acknowledges the preceding
    // client packet. That packet is either a data segment, or one whose
    // sequence number is acked as seq + 1 (e.g. a SYN).
    if (dir == 1 && count > 0 && series[count-1].dir == 0 &&
        ((series[count-1].data_len > 0 &&
        series[count].ack == series[count-1].seq + series[count-1].data_len) ||
        series[count].ack == series[count-1].seq + 1))
    {
        rtt_samps_clnt ++;
    }

    ++count;

    if (count == series_len)
        compute();
}


/*
 * Finalizes the derived features from the recorded series, then frees the
 * series buffer:
 *  - avg_seg_serv: mean server payload size (it held a sum until now).
 *  - eth_data_byte_var_serv: sample variance of the IP lengths of
 *    server-side packets.
 *  - ip_data_byte_med_clnt: median IP length of client-side packets.
 *
 * Must run at most once per flow, because the series is released at the
 * end.
 */
void MLFeatureExtractor::FlowState::compute()
{
    assert(series != 0);

    if (data_pkts_serv > 1)
        avg_seg_serv /= data_pkts_serv;

    if (pkts_serv > 1)
    {
        // Work in double: the per-packet deviation from the mean is usually
        // negative for some packets. Converting that to an unsigned type (as
        // the old size_t temporary did) is undefined for floating-point
        // values and drops the fractional part.
        const double mean = static_cast<double>(avg_eth_data_byte_serv) / pkts_serv;
        avg_eth_data_byte_serv /= pkts_serv;
        for (size_t i = 0; i < count; ++i)
        {
            if (series[i].dir == 0)
                continue;
            const double diff = static_cast<double>(series[i].len) - mean;
            eth_data_byte_var_serv += diff * diff;
        }
        eth_data_byte_var_serv /= pkts_serv - 1;
    }

    if (pkts_clnt > 0)
    {
        // Heap-backed buffer instead of a non-standard variable-length array.
        std::vector<uint16_t> lens;
        lens.reserve(count);
        for (size_t i = 0; i < count; ++i)
        {
            if (series[i].dir == 1)
                continue;
            lens.push_back(static_cast<uint16_t>(series[i].len));
        }
        // NOTE(review): bytes_median() takes a uint8_t count. Safe while
        // lens.size() <= series_len < 256.
        if (!lens.empty())
            ip_data_byte_med_clnt = bytes_median(&lens[0], lens.size());
    }

    delete[] series;
    series = 0;
}

namespace
{
    // qsort() comparator: orders uint16_t values ascending. Both operands are
    // widened to int first, so the difference cannot overflow.
    int bytes_compare(const void *one, const void *two)
    {
        const uint16_t a = *static_cast<const uint16_t*>(one);
        const uint16_t b = *static_cast<const uint16_t*>(two);
        return static_cast<int>(a) - static_cast<int>(b);
    }

    // Returns the median of the first n values of arr, sorting arr in place
    // when n > 1. For an even count the two middle values are averaged.
    // An empty input yields 0 instead of reading outside the array.
    double bytes_median(uint16_t arr[], uint8_t n)
    {
        if (n == 0)
            return 0.0;
        if (n == 1)
            return static_cast<double>(arr[0]);

        qsort(arr, n, sizeof(uint16_t), bytes_compare);
        const size_t mid = n / 2;
        if (n % 2 == 1)
            return static_cast<double>(arr[mid]);
        return (static_cast<double>(arr[mid - 1]) + arr[mid]) / 2.0;
    }
}

ELEMENT_REQUIRES(FlowCache)
EXPORT_ELEMENT(MLFeatureExtractor)

#include <click/vector.cc>

CLICK_ENDDECLS

