/*
 * mlflowcls.{cc,hh} -- classify TCP/UDP flows according to the class of
 *                      applications that generate them using a
 *                      machine-learning classifier
 * Marco Canini
 *
 * Copyright (c) 2008-09 by University of Genova - DIST - TNT laboratory
 *
 * Redistribution and use in source and binary forms, with or
 * without modification, are permitted provided that the following
 * conditions are met:
 *
 * * Redistributions of source code must retain the above copyright
 *   notice, this list of conditions and the following disclaimer.
 * * Redistributions in binary form must reproduce the above copyright
 *   notice, this list of conditions and the following disclaimer in
 *   the documentation and/or other materials provided with the distribution.
 * * Neither the name of University of Genova nor the names of its contributors
 *   may be used to endorse or promote products derived from this software
 *   without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
 * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
 * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * $Id: mlflowcls.cc 2578 2009-10-05 10:32:45Z marco $
 */

#include <click/config.h>
#include <click/error.hh>
#include <click/confparse.hh>

#include <iostream>

#include "mlflowcls.hh"

CLICK_DECLS

/*
 * Constructs an unconfigured classifier element.
 *
 * The pointers to the collaborating elements are cleared here and only become
 * valid once configure() succeeds. This way a failed or partial configuration
 * never leaves them holding indeterminate values.
 */
MLFlowCls::MLFlowCls()
{
    flow_cache_ = 0;
    ml_feat_extr_ = 0;
    cls_ = 0;
}

/*
 * Destroys the element. Nothing is released here.
 *
 * The FlowCache, MLFeatureExtractor and WekaCls elements are only referenced,
 * not owned. Per-flow FlowState objects are handed to the FlowCache in
 * simple_action(); presumably the cache disposes of them (TODO confirm).
 */
MLFlowCls::~MLFlowCls() { }

/*
 * Parses the element configuration and wires up the collaborating elements.
 *
 * Keywords (both mandatory):
 *   MLFEATUREEXTRACTOR  element that must cast to MLFeatureExtractor
 *   WEKACLS             element that must cast to WekaCls
 *
 * On success, also locates the upstream FlowCache and registers this element
 * with it as a flow state holder.
 *
 * Returns 0 on success, or a negative value after reporting the problem
 * through errh.
 */
int MLFlowCls::configure(Vector<String>& conf, ErrorHandler* errh)
{
    Element* e1 = 0;
    Element* e2 = 0;

    if (cp_va_kparse(conf, this, errh,
            "MLFEATUREEXTRACTOR", cpkM, cpElement, &e1,
            "WEKACLS", cpkM, cpElement, &e2,
            cpEnd) < 0)
        return -1;

    // Validate the referenced elements first, so a misconfiguration does not
    // leave this element registered with the FlowCache. A null element is
    // rejected as well, so both pointers are guaranteed non-null on success.
    ml_feat_extr_ = e1 ? static_cast<MLFeatureExtractor*>(e1->cast("MLFeatureExtractor")) : 0;
    if (!ml_feat_extr_)
        return errh->error("MLFEATUREEXTRACTOR must be a MLFeatureExtractor element");

    cls_ = e2 ? static_cast<WekaCls*>(e2->cast("WekaCls")) : 0;
    if (!cls_)
        return errh->error("WEKACLS must be a WekaCls element");

    flow_cache_ = FlowCache::upstream_instance(this);
    if (!flow_cache_)
        return errh->error("Initialization failure!");

    flow_cache_->register_flow_state_holder(this, name());
    return 0;
}

/*
 * Per-packet processing: ensures the packet's flow has a FlowState in the
 * FlowCache and attempts to classify the flow. It then stamps the flow's
 * current mark into the packet's APPMARK annotation.
 *
 * Returns the (annotated) packet. Returns 0 after killing the packet if a new
 * FlowState could not be allocated.
 */
Packet* MLFlowCls::simple_action(Packet* p)
{
    FlowState* fs = flow_cache_->lookup_state<FlowState>(this, p);

    if (!fs)
    {
        fs = new FlowState();

        // Check the allocation before publishing it to the cache, so the cache
        // never holds a null state for this flow. flow_over() dereferences
        // the stored state. `new` may return null in builds without
        // exceptions (presumably Click kernel builds).
        if (!fs)
        {
            click_chatter("out of memory!");
            p->kill();
            return 0;
        }

        flow_cache_->state(this, p) = fs;
    }

    FlowCache::Flow f = flow_cache_->lookup_flow(p);
    try_classify(f, fs);

    SET_APPMARK_ANNO(p, fs->get_mark());
    return p;
}

/*
 * Attempts to classify `flow` and store the resulting AppMark in `fs`.
 *
 * Reads the features gathered by the MLFeatureExtractor for this flow. If they
 * are not ready yet, nothing is done. Otherwise a 12-element feature vector is
 * fed to the WekaCls classifier.
 *
 * Returns true if a mark was set, false otherwise. A mark is not set when:
 *   - the features are missing or not ready, or
 *   - the classifier result cannot be encoded as a mark.
 */
bool MLFlowCls::try_classify(const FlowCache::Flow& flow, FlowState* fs)
{
    const MLFeatureExtractor::FlowState* features =
        flow.lookup_state<MLFeatureExtractor::FlowState>(ml_feat_extr_);
    assert(features);

    // In builds where assert() is compiled out, treat missing features as
    // "not ready" instead of dereferencing a null pointer.
    if (!features || !features->ready())
        return false;

    // Feature vector; the order must match the one the classifier model was
    // trained with. Ports are converted from network to host byte order.
    double input[12];

    input[0] = static_cast<double>(ntohs(flow.dst_port()));
    input[1] = static_cast<double>(ntohs(flow.src_port()));
    input[2] = static_cast<double>(features->client_data_pkts());
    input[3] = static_cast<double>(features->client_push_pkts());
    input[4] = static_cast<double>(features->server_push_pkts());
    input[5] = static_cast<double>(features->client_min_seg_size());
    input[6] = static_cast<double>(features->server_avg_seg_size());
    input[7] = static_cast<double>(features->client_bytes_in_init_win());
    input[8] = static_cast<double>(features->server_bytes_in_init_win());
    input[9] = static_cast<double>(features->client_rtt_samples());
    input[10] = static_cast<double>(features->client_median_bytes());
    input[11] = static_cast<double>(features->server_var_data_bytes());

    const double class_index = round(cls_->classify(input));

    // The mark stores (2 + class index) in its top 8 bits, so only indices
    // 0..253 are representable. Converting a negative or NaN double to an
    // unsigned integer is undefined behaviour, so reject those results
    // (NaN fails both comparisons). NOTE(review): Weka-style classifiers may
    // report NaN when no class can be assigned — confirm for WekaCls.
    if (!(class_index >= 0 && class_index <= 253))
        return false;

    /* NOTE: Class 0 and 1 are reserved in AppMarks. Therefore, we add 2 to the
       classification result since the first valid class in AppMarks is 2. */
    AppMark m((2 + static_cast<uint32_t>(class_index)) << 24);
    fs->set_mark(m);
    return true;
}

/*
 * Called when a flow ends. This is a last chance to classify it: a flow whose
 * mark is still AppMarks::empty_mark gets one more try_classify() attempt.
 * Already-marked flows are left untouched.
 */
void MLFlowCls::flow_over(const FlowCache::Flow& flow, BaseFlowState* fs_)
{
    FlowState* state = static_cast<FlowState*>(fs_);

    if (!(*state == AppMarks::empty_mark))
        return;

    try_classify(flow, state);
}

/*
 * Declares this element's single export column, "MLFlowClsMark". The column
 * descriptor is obtained from the default AppMarks instance.
 */
void MLFlowCls::write_header(DataExport& exporter) const
{
    const char* const column_name = "MLFlowClsMark";
    exporter += AppMarks::default_instance()->column(column_name, true);
}

/*
 * Writes the flow's classification mark to the exporter. This is the value
 * for the "MLFlowClsMark" column declared in write_header().
 *
 * fs_ is this element's per-flow state. It is downcast to FlowState and then
 * viewed as an AppMark for output.
 * NOTE(review): the C-style cast assumes FlowState derives from (or is
 * layout-compatible with) AppMark — confirm against mlflowcls.hh.
 */
void MLFlowCls::write_flow_state(DataExport& exporter, const BaseFlowState* fs_) const
{
    const FlowState* fs = static_cast<const FlowState*>(fs_);
    exporter << *((const AppMark *) fs);
}

CLICK_ENDDECLS

ELEMENT_REQUIRES(FlowCache)
EXPORT_ELEMENT(MLFlowCls)

#include <click/vector.cc>
